Ashoka74 commited on
Commit
947db12
ยท
verified ยท
1 Parent(s): 3b7d18f

Upload 13 files

Browse files
mvadapter/__init__.py ADDED
File without changes
mvadapter/loaders/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .custom_adapter import CustomAdapterMixin
mvadapter/loaders/custom_adapter.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, Optional, Union
3
+
4
+ import safetensors
5
+ import torch
6
+ from diffusers.utils import _get_model_file, logging
7
+ from safetensors import safe_open
8
+
9
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
10
+
11
+
12
+ class CustomAdapterMixin:
13
+ def init_custom_adapter(self, *args, **kwargs):
14
+ self._init_custom_adapter(*args, **kwargs)
15
+
16
+ def _init_custom_adapter(self, *args, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def load_custom_adapter(
20
+ self,
21
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
22
+ weight_name: str,
23
+ subfolder: Optional[str] = None,
24
+ **kwargs,
25
+ ):
26
+ # Load the main state dict first.
27
+ cache_dir = kwargs.pop("cache_dir", None)
28
+ force_download = kwargs.pop("force_download", False)
29
+ proxies = kwargs.pop("proxies", None)
30
+ local_files_only = kwargs.pop("local_files_only", None)
31
+ token = kwargs.pop("token", None)
32
+ revision = kwargs.pop("revision", None)
33
+
34
+ user_agent = {
35
+ "file_type": "attn_procs_weights",
36
+ "framework": "pytorch",
37
+ }
38
+
39
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
40
+ model_file = _get_model_file(
41
+ pretrained_model_name_or_path_or_dict,
42
+ weights_name=weight_name,
43
+ subfolder=subfolder,
44
+ cache_dir=cache_dir,
45
+ force_download=force_download,
46
+ proxies=proxies,
47
+ local_files_only=local_files_only,
48
+ token=token,
49
+ revision=revision,
50
+ user_agent=user_agent,
51
+ )
52
+ if weight_name.endswith(".safetensors"):
53
+ state_dict = {}
54
+ with safe_open(model_file, framework="pt", device="cpu") as f:
55
+ for key in f.keys():
56
+ state_dict[key] = f.get_tensor(key)
57
+ else:
58
+ state_dict = torch.load(model_file, map_location="cpu")
59
+ else:
60
+ state_dict = pretrained_model_name_or_path_or_dict
61
+
62
+ self._load_custom_adapter(state_dict)
63
+
64
+ def _load_custom_adapter(self, state_dict):
65
+ raise NotImplementedError
66
+
67
+ def save_custom_adapter(
68
+ self,
69
+ save_directory: Union[str, os.PathLike],
70
+ weight_name: str,
71
+ safe_serialization: bool = False,
72
+ **kwargs,
73
+ ):
74
+ if os.path.isfile(save_directory):
75
+ logger.error(
76
+ f"Provided path ({save_directory}) should be a directory, not a file"
77
+ )
78
+ return
79
+
80
+ if safe_serialization:
81
+
82
+ def save_function(weights, filename):
83
+ return safetensors.torch.save_file(
84
+ weights, filename, metadata={"format": "pt"}
85
+ )
86
+
87
+ else:
88
+ save_function = torch.save
89
+
90
+ # Save the model
91
+ state_dict = self._save_custom_adapter(**kwargs)
92
+ save_function(state_dict, os.path.join(save_directory, weight_name))
93
+ logger.info(
94
+ f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}"
95
+ )
96
+
97
+ def _save_custom_adapter(self):
98
+ raise NotImplementedError
mvadapter/models/__init__.py ADDED
File without changes
mvadapter/models/attention_processor.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Callable, List, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from diffusers.models.attention_processor import Attention
7
+ from diffusers.models.unets import UNet2DConditionModel
8
+ from diffusers.utils import deprecate, logging
9
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
10
+ from einops import rearrange
11
+ from torch import nn
12
+
13
+
14
+ def default_set_attn_proc_func(
15
+ name: str,
16
+ hidden_size: int,
17
+ cross_attention_dim: Optional[int],
18
+ ori_attn_proc: object,
19
+ ) -> object:
20
+ return ori_attn_proc
21
+
22
+
23
+ def set_unet_2d_condition_attn_processor(
24
+ unet: UNet2DConditionModel,
25
+ set_self_attn_proc_func: Callable = default_set_attn_proc_func,
26
+ set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
27
+ set_custom_attn_proc_func: Callable = default_set_attn_proc_func,
28
+ set_self_attn_module_names: Optional[List[str]] = None,
29
+ set_cross_attn_module_names: Optional[List[str]] = None,
30
+ set_custom_attn_module_names: Optional[List[str]] = None,
31
+ ) -> None:
32
+ do_set_processor = lambda name, module_names: (
33
+ any([name.startswith(module_name) for module_name in module_names])
34
+ if module_names is not None
35
+ else True
36
+ ) # prefix match
37
+
38
+ attn_procs = {}
39
+ for name, attn_processor in unet.attn_processors.items():
40
+ # set attn_processor by default, if module_names is None
41
+ set_self_attn_processor = do_set_processor(name, set_self_attn_module_names)
42
+ set_cross_attn_processor = do_set_processor(name, set_cross_attn_module_names)
43
+ set_custom_attn_processor = do_set_processor(name, set_custom_attn_module_names)
44
+
45
+ if name.startswith("mid_block"):
46
+ hidden_size = unet.config.block_out_channels[-1]
47
+ elif name.startswith("up_blocks"):
48
+ block_id = int(name[len("up_blocks.")])
49
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
50
+ elif name.startswith("down_blocks"):
51
+ block_id = int(name[len("down_blocks.")])
52
+ hidden_size = unet.config.block_out_channels[block_id]
53
+
54
+ is_custom = "attn_mid_blocks" in name or "attn_post_blocks" in name
55
+ if is_custom:
56
+ attn_procs[name] = (
57
+ set_custom_attn_proc_func(name, hidden_size, None, attn_processor)
58
+ if set_custom_attn_processor
59
+ else attn_processor
60
+ )
61
+ else:
62
+ cross_attention_dim = (
63
+ None
64
+ if name.endswith("attn1.processor")
65
+ else unet.config.cross_attention_dim
66
+ )
67
+ if cross_attention_dim is None or "motion_modules" in name:
68
+ # self attention
69
+ attn_procs[name] = (
70
+ set_self_attn_proc_func(
71
+ name, hidden_size, cross_attention_dim, attn_processor
72
+ )
73
+ if set_self_attn_processor
74
+ else attn_processor
75
+ )
76
+ else:
77
+ # cross attention
78
+ attn_procs[name] = (
79
+ set_cross_attn_proc_func(
80
+ name, hidden_size, cross_attention_dim, attn_processor
81
+ )
82
+ if set_cross_attn_processor
83
+ else attn_processor
84
+ )
85
+
86
+ unet.set_attn_processor(attn_procs)
87
+
88
+
89
+ class DecoupledMVRowSelfAttnProcessor2_0(torch.nn.Module):
90
+ r"""
91
+ Attention processor for Decoupled Row-wise Self-Attention and Image Cross-Attention for PyTorch 2.0.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ query_dim: int,
97
+ inner_dim: int,
98
+ num_views: int = 1,
99
+ name: Optional[str] = None,
100
+ use_mv: bool = True,
101
+ use_ref: bool = False,
102
+ ):
103
+ if not hasattr(F, "scaled_dot_product_attention"):
104
+ raise ImportError(
105
+ "DecoupledMVRowSelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
106
+ )
107
+
108
+ super().__init__()
109
+
110
+ self.num_views = num_views
111
+ self.name = name # NOTE: need for image cross-attention
112
+ self.use_mv = use_mv
113
+ self.use_ref = use_ref
114
+
115
+ if self.use_mv:
116
+ self.to_q_mv = nn.Linear(
117
+ in_features=query_dim, out_features=inner_dim, bias=False
118
+ )
119
+ self.to_k_mv = nn.Linear(
120
+ in_features=query_dim, out_features=inner_dim, bias=False
121
+ )
122
+ self.to_v_mv = nn.Linear(
123
+ in_features=query_dim, out_features=inner_dim, bias=False
124
+ )
125
+ self.to_out_mv = nn.ModuleList(
126
+ [
127
+ nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
128
+ nn.Dropout(0.0),
129
+ ]
130
+ )
131
+
132
+ if self.use_ref:
133
+ self.to_q_ref = nn.Linear(
134
+ in_features=query_dim, out_features=inner_dim, bias=False
135
+ )
136
+ self.to_k_ref = nn.Linear(
137
+ in_features=query_dim, out_features=inner_dim, bias=False
138
+ )
139
+ self.to_v_ref = nn.Linear(
140
+ in_features=query_dim, out_features=inner_dim, bias=False
141
+ )
142
+ self.to_out_ref = nn.ModuleList(
143
+ [
144
+ nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
145
+ nn.Dropout(0.0),
146
+ ]
147
+ )
148
+
149
+ def __call__(
150
+ self,
151
+ attn: Attention,
152
+ hidden_states: torch.FloatTensor,
153
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
154
+ attention_mask: Optional[torch.FloatTensor] = None,
155
+ temb: Optional[torch.FloatTensor] = None,
156
+ mv_scale: float = 1.0,
157
+ ref_hidden_states: Optional[torch.FloatTensor] = None,
158
+ ref_scale: float = 1.0,
159
+ cache_hidden_states: Optional[List[torch.FloatTensor]] = None,
160
+ use_mv: bool = True,
161
+ use_ref: bool = True,
162
+ *args,
163
+ **kwargs,
164
+ ) -> torch.FloatTensor:
165
+ """
166
+ New args:
167
+ mv_scale (float): scale for multi-view self-attention.
168
+ ref_hidden_states (torch.FloatTensor): reference encoder hidden states for image cross-attention.
169
+ ref_scale (float): scale for image cross-attention.
170
+ cache_hidden_states (List[torch.FloatTensor]): cache hidden states from reference unet.
171
+
172
+ """
173
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
174
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
175
+ deprecate("scale", "1.0.0", deprecation_message)
176
+
177
+ # NEW: cache hidden states for reference unet
178
+ if cache_hidden_states is not None:
179
+ cache_hidden_states[self.name] = hidden_states.clone()
180
+
181
+ # NEW: whether to use multi-view attention and image cross-attention
182
+ use_mv = self.use_mv and use_mv
183
+ use_ref = self.use_ref and use_ref
184
+
185
+ residual = hidden_states
186
+ if attn.spatial_norm is not None:
187
+ hidden_states = attn.spatial_norm(hidden_states, temb)
188
+
189
+ input_ndim = hidden_states.ndim
190
+
191
+ if input_ndim == 4:
192
+ batch_size, channel, height, width = hidden_states.shape
193
+ hidden_states = hidden_states.view(
194
+ batch_size, channel, height * width
195
+ ).transpose(1, 2)
196
+
197
+ batch_size, sequence_length, _ = (
198
+ hidden_states.shape
199
+ if encoder_hidden_states is None
200
+ else encoder_hidden_states.shape
201
+ )
202
+
203
+ if attention_mask is not None:
204
+ attention_mask = attn.prepare_attention_mask(
205
+ attention_mask, sequence_length, batch_size
206
+ )
207
+ # scaled_dot_product_attention expects attention_mask shape to be
208
+ # (batch, heads, source_length, target_length)
209
+ attention_mask = attention_mask.view(
210
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
211
+ )
212
+
213
+ if attn.group_norm is not None:
214
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
215
+ 1, 2
216
+ )
217
+
218
+ query = attn.to_q(hidden_states)
219
+
220
+ # NEW: for decoupled multi-view attention
221
+ if use_mv:
222
+ query_mv = self.to_q_mv(hidden_states)
223
+
224
+ # NEW: for decoupled reference cross attention
225
+ if use_ref:
226
+ query_ref = self.to_q_ref(hidden_states)
227
+
228
+ if encoder_hidden_states is None:
229
+ encoder_hidden_states = hidden_states
230
+ elif attn.norm_cross:
231
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
232
+ encoder_hidden_states
233
+ )
234
+
235
+ key = attn.to_k(encoder_hidden_states)
236
+ value = attn.to_v(encoder_hidden_states)
237
+
238
+ inner_dim = key.shape[-1]
239
+ head_dim = inner_dim // attn.heads
240
+
241
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
242
+
243
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
244
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
245
+
246
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
247
+ # TODO: add support for attn.scale when we move to Torch 2.1
248
+ hidden_states = F.scaled_dot_product_attention(
249
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
250
+ )
251
+
252
+ hidden_states = hidden_states.transpose(1, 2).reshape(
253
+ batch_size, -1, attn.heads * head_dim
254
+ )
255
+ hidden_states = hidden_states.to(query.dtype)
256
+
257
+ ####### Decoupled multi-view self-attention ########
258
+ if use_mv:
259
+ key_mv = self.to_k_mv(encoder_hidden_states)
260
+ value_mv = self.to_v_mv(encoder_hidden_states)
261
+
262
+ query_mv = query_mv.view(batch_size, -1, attn.heads, head_dim)
263
+ key_mv = key_mv.view(batch_size, -1, attn.heads, head_dim)
264
+ value_mv = value_mv.view(batch_size, -1, attn.heads, head_dim)
265
+
266
+ height = width = math.isqrt(sequence_length)
267
+
268
+ # row self-attention
269
+ query_mv = rearrange(
270
+ query_mv,
271
+ "(b nv) (ih iw) h c -> (b nv ih) iw h c",
272
+ nv=self.num_views,
273
+ ih=height,
274
+ iw=width,
275
+ ).transpose(1, 2)
276
+ key_mv = rearrange(
277
+ key_mv,
278
+ "(b nv) (ih iw) h c -> b ih (nv iw) h c",
279
+ nv=self.num_views,
280
+ ih=height,
281
+ iw=width,
282
+ )
283
+ key_mv = (
284
+ key_mv.repeat_interleave(self.num_views, dim=0)
285
+ .view(batch_size * height, -1, attn.heads, head_dim)
286
+ .transpose(1, 2)
287
+ )
288
+ value_mv = rearrange(
289
+ value_mv,
290
+ "(b nv) (ih iw) h c -> b ih (nv iw) h c",
291
+ nv=self.num_views,
292
+ ih=height,
293
+ iw=width,
294
+ )
295
+ value_mv = (
296
+ value_mv.repeat_interleave(self.num_views, dim=0)
297
+ .view(batch_size * height, -1, attn.heads, head_dim)
298
+ .transpose(1, 2)
299
+ )
300
+
301
+ hidden_states_mv = F.scaled_dot_product_attention(
302
+ query_mv,
303
+ key_mv,
304
+ value_mv,
305
+ dropout_p=0.0,
306
+ is_causal=False,
307
+ )
308
+ hidden_states_mv = rearrange(
309
+ hidden_states_mv,
310
+ "(b nv ih) h iw c -> (b nv) (ih iw) (h c)",
311
+ nv=self.num_views,
312
+ ih=height,
313
+ )
314
+ hidden_states_mv = hidden_states_mv.to(query.dtype)
315
+
316
+ # linear proj
317
+ hidden_states_mv = self.to_out_mv[0](hidden_states_mv)
318
+ # dropout
319
+ hidden_states_mv = self.to_out_mv[1](hidden_states_mv)
320
+
321
+ if use_ref:
322
+ reference_hidden_states = ref_hidden_states[self.name]
323
+
324
+ key_ref = self.to_k_ref(reference_hidden_states)
325
+ value_ref = self.to_v_ref(reference_hidden_states)
326
+
327
+ query_ref = query_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
328
+ 1, 2
329
+ )
330
+ key_ref = key_ref.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
331
+ value_ref = value_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
332
+ 1, 2
333
+ )
334
+
335
+ hidden_states_ref = F.scaled_dot_product_attention(
336
+ query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False
337
+ )
338
+
339
+ hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape(
340
+ batch_size, -1, attn.heads * head_dim
341
+ )
342
+ hidden_states_ref = hidden_states_ref.to(query.dtype)
343
+
344
+ # linear proj
345
+ hidden_states_ref = self.to_out_ref[0](hidden_states_ref)
346
+ # dropout
347
+ hidden_states_ref = self.to_out_ref[1](hidden_states_ref)
348
+
349
+ # linear proj
350
+ hidden_states = attn.to_out[0](hidden_states)
351
+ # dropout
352
+ hidden_states = attn.to_out[1](hidden_states)
353
+
354
+ if use_mv:
355
+ hidden_states = hidden_states + hidden_states_mv * mv_scale
356
+
357
+ if use_ref:
358
+ hidden_states = hidden_states + hidden_states_ref * ref_scale
359
+
360
+ if input_ndim == 4:
361
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
362
+ batch_size, channel, height, width
363
+ )
364
+
365
+ if attn.residual_connection:
366
+ hidden_states = hidden_states + residual
367
+
368
+ hidden_states = hidden_states / attn.rescale_output_factor
369
+
370
+ return hidden_states
371
+
372
+ def set_num_views(self, num_views: int) -> None:
373
+ self.num_views = num_views
mvadapter/pipelines/pipeline_mvadapter_i2mv_sdxl.py ADDED
@@ -0,0 +1,953 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL
20
+ import torch
21
+ import torch.nn as nn
22
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
23
+ from diffusers.models import (
24
+ AutoencoderKL,
25
+ ImageProjection,
26
+ T2IAdapter,
27
+ UNet2DConditionModel,
28
+ )
29
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
30
+ StableDiffusionXLPipelineOutput,
31
+ )
32
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
33
+ StableDiffusionXLPipeline,
34
+ rescale_noise_cfg,
35
+ retrieve_timesteps,
36
+ )
37
+ from diffusers.schedulers import KarrasDiffusionSchedulers
38
+ from diffusers.utils import deprecate, logging
39
+ from diffusers.utils.torch_utils import randn_tensor
40
+ from einops import rearrange
41
+ from transformers import (
42
+ CLIPImageProcessor,
43
+ CLIPTextModel,
44
+ CLIPTextModelWithProjection,
45
+ CLIPTokenizer,
46
+ CLIPVisionModelWithProjection,
47
+ )
48
+
49
+ from ..loaders import CustomAdapterMixin
50
+ from ..models.attention_processor import (
51
+ DecoupledMVRowSelfAttnProcessor2_0,
52
+ set_unet_2d_condition_attn_processor,
53
+ )
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+
58
+ def retrieve_latents(
59
+ encoder_output: torch.Tensor,
60
+ generator: Optional[torch.Generator] = None,
61
+ sample_mode: str = "sample",
62
+ ):
63
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
64
+ return encoder_output.latent_dist.sample(generator)
65
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
66
+ return encoder_output.latent_dist.mode()
67
+ elif hasattr(encoder_output, "latents"):
68
+ return encoder_output.latents
69
+ else:
70
+ raise AttributeError("Could not access latents of provided encoder_output")
71
+
72
+
73
+ class MVAdapterI2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
74
+ def __init__(
75
+ self,
76
+ vae: AutoencoderKL,
77
+ text_encoder: CLIPTextModel,
78
+ text_encoder_2: CLIPTextModelWithProjection,
79
+ tokenizer: CLIPTokenizer,
80
+ tokenizer_2: CLIPTokenizer,
81
+ unet: UNet2DConditionModel,
82
+ scheduler: KarrasDiffusionSchedulers,
83
+ image_encoder: CLIPVisionModelWithProjection = None,
84
+ feature_extractor: CLIPImageProcessor = None,
85
+ force_zeros_for_empty_prompt: bool = True,
86
+ add_watermarker: Optional[bool] = None,
87
+ ):
88
+ super().__init__(
89
+ vae=vae,
90
+ text_encoder=text_encoder,
91
+ text_encoder_2=text_encoder_2,
92
+ tokenizer=tokenizer,
93
+ tokenizer_2=tokenizer_2,
94
+ unet=unet,
95
+ scheduler=scheduler,
96
+ image_encoder=image_encoder,
97
+ feature_extractor=feature_extractor,
98
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
99
+ add_watermarker=add_watermarker,
100
+ )
101
+
102
+ self.control_image_processor = VaeImageProcessor(
103
+ vae_scale_factor=self.vae_scale_factor,
104
+ do_convert_rgb=True,
105
+ do_normalize=False,
106
+ )
107
+
108
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.prepare_latents
109
+ def prepare_image_latents(
110
+ self,
111
+ image,
112
+ timestep,
113
+ batch_size,
114
+ num_images_per_prompt,
115
+ dtype,
116
+ device,
117
+ generator=None,
118
+ add_noise=True,
119
+ ):
120
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
121
+ raise ValueError(
122
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
123
+ )
124
+
125
+ latents_mean = latents_std = None
126
+ if (
127
+ hasattr(self.vae.config, "latents_mean")
128
+ and self.vae.config.latents_mean is not None
129
+ ):
130
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
131
+ if (
132
+ hasattr(self.vae.config, "latents_std")
133
+ and self.vae.config.latents_std is not None
134
+ ):
135
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
136
+
137
+ # Offload text encoder if `enable_model_cpu_offload` was enabled
138
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
139
+ self.text_encoder_2.to("cpu")
140
+ torch.cuda.empty_cache()
141
+
142
+ image = image.to(device=device, dtype=dtype)
143
+
144
+ batch_size = batch_size * num_images_per_prompt
145
+
146
+ if image.shape[1] == 4:
147
+ init_latents = image
148
+
149
+ else:
150
+ # make sure the VAE is in float32 mode, as it overflows in float16
151
+ if self.vae.config.force_upcast:
152
+ image = image.float()
153
+ self.vae.to(dtype=torch.float32)
154
+
155
+ if isinstance(generator, list) and len(generator) != batch_size:
156
+ raise ValueError(
157
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
158
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
159
+ )
160
+
161
+ elif isinstance(generator, list):
162
+ if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
163
+ image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
164
+ elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
165
+ raise ValueError(
166
+ f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
167
+ )
168
+
169
+ init_latents = [
170
+ retrieve_latents(
171
+ self.vae.encode(image[i : i + 1]), generator=generator[i]
172
+ )
173
+ for i in range(batch_size)
174
+ ]
175
+ init_latents = torch.cat(init_latents, dim=0)
176
+ else:
177
+ init_latents = retrieve_latents(
178
+ self.vae.encode(image), generator=generator
179
+ )
180
+
181
+ if self.vae.config.force_upcast:
182
+ self.vae.to(dtype)
183
+
184
+ init_latents = init_latents.to(dtype)
185
+ if latents_mean is not None and latents_std is not None:
186
+ latents_mean = latents_mean.to(device=device, dtype=dtype)
187
+ latents_std = latents_std.to(device=device, dtype=dtype)
188
+ init_latents = (
189
+ (init_latents - latents_mean)
190
+ * self.vae.config.scaling_factor
191
+ / latents_std
192
+ )
193
+ else:
194
+ init_latents = self.vae.config.scaling_factor * init_latents
195
+
196
+ if (
197
+ batch_size > init_latents.shape[0]
198
+ and batch_size % init_latents.shape[0] == 0
199
+ ):
200
+ # expand init_latents for batch_size
201
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
202
+ init_latents = torch.cat(
203
+ [init_latents] * additional_image_per_prompt, dim=0
204
+ )
205
+ elif (
206
+ batch_size > init_latents.shape[0]
207
+ and batch_size % init_latents.shape[0] != 0
208
+ ):
209
+ raise ValueError(
210
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
211
+ )
212
+ else:
213
+ init_latents = torch.cat([init_latents], dim=0)
214
+
215
+ if add_noise:
216
+ shape = init_latents.shape
217
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
218
+ # get latents
219
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
220
+
221
+ latents = init_latents
222
+
223
+ return latents
224
+
225
+ def prepare_control_image(
226
+ self,
227
+ image,
228
+ width,
229
+ height,
230
+ batch_size,
231
+ num_images_per_prompt,
232
+ device,
233
+ dtype,
234
+ do_classifier_free_guidance=False,
235
+ num_empty_images=0, # for concat in batch like ImageDream
236
+ ):
237
+ assert hasattr(
238
+ self, "control_image_processor"
239
+ ), "control_image_processor is not initialized"
240
+
241
+ image = self.control_image_processor.preprocess(
242
+ image, height=height, width=width
243
+ ).to(dtype=torch.float32)
244
+
245
+ if num_empty_images > 0:
246
+ image = torch.cat(
247
+ [image, torch.zeros_like(image[:num_empty_images])], dim=0
248
+ )
249
+
250
+ image_batch_size = image.shape[0]
251
+
252
+ if image_batch_size == 1:
253
+ repeat_by = batch_size
254
+ else:
255
+ # image batch size is the same as prompt batch size
256
+ repeat_by = num_images_per_prompt # always 1 for control image
257
+
258
+ image = image.repeat_interleave(repeat_by, dim=0)
259
+
260
+ image = image.to(device=device, dtype=dtype)
261
+
262
+ if do_classifier_free_guidance:
263
+ image = torch.cat([image] * 2)
264
+
265
+ return image
266
+
267
+ @torch.no_grad()
268
+ def __call__(
269
+ self,
270
+ prompt: Union[str, List[str]] = None,
271
+ prompt_2: Optional[Union[str, List[str]]] = None,
272
+ height: Optional[int] = None,
273
+ width: Optional[int] = None,
274
+ num_inference_steps: int = 50,
275
+ timesteps: List[int] = None,
276
+ denoising_end: Optional[float] = None,
277
+ guidance_scale: float = 5.0,
278
+ negative_prompt: Optional[Union[str, List[str]]] = None,
279
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
280
+ num_images_per_prompt: Optional[int] = 1,
281
+ eta: float = 0.0,
282
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
283
+ latents: Optional[torch.FloatTensor] = None,
284
+ prompt_embeds: Optional[torch.FloatTensor] = None,
285
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
286
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
287
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
288
+ ip_adapter_image: Optional[PipelineImageInput] = None,
289
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
290
+ output_type: Optional[str] = "pil",
291
+ return_dict: bool = True,
292
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
293
+ guidance_rescale: float = 0.0,
294
+ original_size: Optional[Tuple[int, int]] = None,
295
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
296
+ target_size: Optional[Tuple[int, int]] = None,
297
+ negative_original_size: Optional[Tuple[int, int]] = None,
298
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
299
+ negative_target_size: Optional[Tuple[int, int]] = None,
300
+ clip_skip: Optional[int] = None,
301
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
302
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
303
+ # NEW
304
+ mv_scale: float = 1.0,
305
+ # Camera or geometry condition
306
+ control_image: Optional[PipelineImageInput] = None,
307
+ control_conditioning_scale: Optional[float] = 1.0,
308
+ control_conditioning_factor: float = 1.0,
309
+ # Image condition
310
+ reference_image: Optional[PipelineImageInput] = None,
311
+ reference_conditioning_scale: Optional[float] = 1.0,
312
+ **kwargs,
313
+ ):
314
+ r"""
315
+ Function invoked when calling the pipeline for generation.
316
+
317
+ Args:
318
+ prompt (`str` or `List[str]`, *optional*):
319
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
320
+ instead.
321
+ prompt_2 (`str` or `List[str]`, *optional*):
322
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
323
+ used in both text-encoders
324
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
325
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
326
+ Anything below 512 pixels won't work well for
327
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
328
+ and checkpoints that are not specifically fine-tuned on low resolutions.
329
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
330
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
331
+ Anything below 512 pixels won't work well for
332
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
333
+ and checkpoints that are not specifically fine-tuned on low resolutions.
334
+ num_inference_steps (`int`, *optional*, defaults to 50):
335
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
336
+ expense of slower inference.
337
+ timesteps (`List[int]`, *optional*):
338
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
339
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
340
+ passed will be used. Must be in descending order.
341
+ denoising_end (`float`, *optional*):
342
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
343
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
344
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
345
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
346
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
347
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
348
+ guidance_scale (`float`, *optional*, defaults to 5.0):
349
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
350
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
351
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
352
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
353
+ usually at the expense of lower image quality.
354
+ negative_prompt (`str` or `List[str]`, *optional*):
355
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
356
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
357
+ less than `1`).
358
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
359
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
360
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
361
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
362
+ The number of images to generate per prompt.
363
+ eta (`float`, *optional*, defaults to 0.0):
364
+ Corresponds to parameter eta (ฮท) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
365
+ [`schedulers.DDIMScheduler`], will be ignored for others.
366
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
367
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
368
+ to make generation deterministic.
369
+ latents (`torch.FloatTensor`, *optional*):
370
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
371
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
372
+ tensor will ge generated by sampling using the supplied random `generator`.
373
+ prompt_embeds (`torch.FloatTensor`, *optional*):
374
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
375
+ provided, text embeddings will be generated from `prompt` input argument.
376
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
377
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
378
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
379
+ argument.
380
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
381
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
382
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
383
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
384
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
385
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
386
+ input argument.
387
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
388
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
389
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
390
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
391
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
392
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
393
+ output_type (`str`, *optional*, defaults to `"pil"`):
394
+ The output format of the generate image. Choose between
395
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
396
+ return_dict (`bool`, *optional*, defaults to `True`):
397
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
398
+ of a plain tuple.
399
+ cross_attention_kwargs (`dict`, *optional*):
400
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
401
+ `self.processor` in
402
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
403
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
404
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
405
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `ฯ†` in equation 16. of
406
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
407
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
408
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
409
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
410
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
411
+ explained in section 2.2 of
412
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
413
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
414
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
415
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
416
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
417
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
418
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
419
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
420
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
421
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
422
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
423
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
424
+ micro-conditioning as explained in section 2.2 of
425
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
426
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
427
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
428
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
429
+ micro-conditioning as explained in section 2.2 of
430
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
431
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
432
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
433
+ To negatively condition the generation process based on a target image resolution. It should be as same
434
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
435
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
436
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
437
+ callback_on_step_end (`Callable`, *optional*):
438
+ A function that calls at the end of each denoising steps during the inference. The function is called
439
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
440
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
441
+ `callback_on_step_end_tensor_inputs`.
442
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
443
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
444
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
445
+ `._callback_tensor_inputs` attribute of your pipeline class.
446
+
447
+ Examples:
448
+
449
+ Returns:
450
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
451
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
452
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
453
+ """
454
+
455
+ callback = kwargs.pop("callback", None)
456
+ callback_steps = kwargs.pop("callback_steps", None)
457
+
458
+ if callback is not None:
459
+ deprecate(
460
+ "callback",
461
+ "1.0.0",
462
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
463
+ )
464
+ if callback_steps is not None:
465
+ deprecate(
466
+ "callback_steps",
467
+ "1.0.0",
468
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
469
+ )
470
+
471
+ # 0. Default height and width to unet
472
+ height = height or self.default_sample_size * self.vae_scale_factor
473
+ width = width or self.default_sample_size * self.vae_scale_factor
474
+
475
+ original_size = original_size or (height, width)
476
+ target_size = target_size or (height, width)
477
+
478
+ # 1. Check inputs. Raise error if not correct
479
+ self.check_inputs(
480
+ prompt,
481
+ prompt_2,
482
+ height,
483
+ width,
484
+ callback_steps,
485
+ negative_prompt,
486
+ negative_prompt_2,
487
+ prompt_embeds,
488
+ negative_prompt_embeds,
489
+ pooled_prompt_embeds,
490
+ negative_pooled_prompt_embeds,
491
+ ip_adapter_image,
492
+ ip_adapter_image_embeds,
493
+ callback_on_step_end_tensor_inputs,
494
+ )
495
+
496
+ self._guidance_scale = guidance_scale
497
+ self._guidance_rescale = guidance_rescale
498
+ self._clip_skip = clip_skip
499
+ self._cross_attention_kwargs = cross_attention_kwargs
500
+ self._denoising_end = denoising_end
501
+ self._interrupt = False
502
+
503
+ # 2. Define call parameters
504
+ if prompt is not None and isinstance(prompt, str):
505
+ batch_size = 1
506
+ elif prompt is not None and isinstance(prompt, list):
507
+ batch_size = len(prompt)
508
+ else:
509
+ batch_size = prompt_embeds.shape[0]
510
+
511
+ device = self._execution_device
512
+
513
+ # 3. Encode input prompt
514
+ lora_scale = (
515
+ self.cross_attention_kwargs.get("scale", None)
516
+ if self.cross_attention_kwargs is not None
517
+ else None
518
+ )
519
+
520
+ (
521
+ prompt_embeds,
522
+ negative_prompt_embeds,
523
+ pooled_prompt_embeds,
524
+ negative_pooled_prompt_embeds,
525
+ ) = self.encode_prompt(
526
+ prompt=prompt,
527
+ prompt_2=prompt_2,
528
+ device=device,
529
+ num_images_per_prompt=num_images_per_prompt,
530
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
531
+ negative_prompt=negative_prompt,
532
+ negative_prompt_2=negative_prompt_2,
533
+ prompt_embeds=prompt_embeds,
534
+ negative_prompt_embeds=negative_prompt_embeds,
535
+ pooled_prompt_embeds=pooled_prompt_embeds,
536
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
537
+ lora_scale=lora_scale,
538
+ clip_skip=self.clip_skip,
539
+ )
540
+
541
+ # 4. Prepare timesteps
542
+ timesteps, num_inference_steps = retrieve_timesteps(
543
+ self.scheduler, num_inference_steps, device, timesteps
544
+ )
545
+
546
+ # 5. Prepare latent variables
547
+ num_channels_latents = self.unet.config.in_channels
548
+ latents = self.prepare_latents(
549
+ batch_size * num_images_per_prompt,
550
+ num_channels_latents,
551
+ height,
552
+ width,
553
+ prompt_embeds.dtype,
554
+ device,
555
+ generator,
556
+ latents,
557
+ )
558
+
559
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
560
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
561
+
562
+ # 7. Prepare added time ids & embeddings
563
+ add_text_embeds = pooled_prompt_embeds
564
+ if self.text_encoder_2 is None:
565
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
566
+ else:
567
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
568
+
569
+ add_time_ids = self._get_add_time_ids(
570
+ original_size,
571
+ crops_coords_top_left,
572
+ target_size,
573
+ dtype=prompt_embeds.dtype,
574
+ text_encoder_projection_dim=text_encoder_projection_dim,
575
+ )
576
+ if negative_original_size is not None and negative_target_size is not None:
577
+ negative_add_time_ids = self._get_add_time_ids(
578
+ negative_original_size,
579
+ negative_crops_coords_top_left,
580
+ negative_target_size,
581
+ dtype=prompt_embeds.dtype,
582
+ text_encoder_projection_dim=text_encoder_projection_dim,
583
+ )
584
+ else:
585
+ negative_add_time_ids = add_time_ids
586
+
587
+ if self.do_classifier_free_guidance:
588
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
589
+ add_text_embeds = torch.cat(
590
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
591
+ )
592
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
593
+
594
+ prompt_embeds = prompt_embeds.to(device)
595
+ add_text_embeds = add_text_embeds.to(device)
596
+ add_time_ids = add_time_ids.to(device).repeat(
597
+ batch_size * num_images_per_prompt, 1
598
+ )
599
+
600
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
601
+ image_embeds = self.prepare_ip_adapter_image_embeds(
602
+ ip_adapter_image,
603
+ ip_adapter_image_embeds,
604
+ device,
605
+ batch_size * num_images_per_prompt,
606
+ self.do_classifier_free_guidance,
607
+ )
608
+
609
+ # Preprocess reference image
610
+ reference_image = self.image_processor.preprocess(reference_image)
611
+ reference_latents = self.prepare_image_latents(
612
+ reference_image,
613
+ timesteps[:1].repeat(batch_size * num_images_per_prompt), # no use
614
+ batch_size,
615
+ 1,
616
+ prompt_embeds.dtype,
617
+ device,
618
+ generator,
619
+ add_noise=False,
620
+ )
621
+
622
+ with torch.no_grad():
623
+ ref_timesteps = torch.zeros_like(timesteps[0])
624
+ ref_hidden_states = {}
625
+
626
+ self.unet(
627
+ reference_latents,
628
+ ref_timesteps,
629
+ encoder_hidden_states=prompt_embeds[-1:],
630
+ added_cond_kwargs={
631
+ "text_embeds": add_text_embeds[-1:],
632
+ "time_ids": add_time_ids[-1:],
633
+ },
634
+ cross_attention_kwargs={
635
+ "cache_hidden_states": ref_hidden_states,
636
+ "use_mv": False,
637
+ "use_ref": False,
638
+ },
639
+ return_dict=False,
640
+ )
641
+ ref_hidden_states = {
642
+ k: v.repeat_interleave(num_images_per_prompt, dim=0)
643
+ for k, v in ref_hidden_states.items()
644
+ }
645
+ if self.do_classifier_free_guidance:
646
+ ref_hidden_states = {
647
+ k: torch.cat([torch.zeros_like(v), v], dim=0)
648
+ for k, v in ref_hidden_states.items()
649
+ }
650
+
651
+ cross_attention_kwargs = {
652
+ "mv_scale": mv_scale,
653
+ "ref_hidden_states": {k: v.clone() for k, v in ref_hidden_states.items()},
654
+ "ref_scale": reference_conditioning_scale,
655
+ **(self.cross_attention_kwargs or {}),
656
+ }
657
+
658
+ # Preprocess control image
659
+ control_image_feature = self.prepare_control_image(
660
+ image=control_image,
661
+ width=width,
662
+ height=height,
663
+ batch_size=batch_size * num_images_per_prompt,
664
+ num_images_per_prompt=1, # NOTE: always 1 for control images
665
+ device=device,
666
+ dtype=latents.dtype,
667
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
668
+ )
669
+ control_image_feature = control_image_feature.to(
670
+ device=device, dtype=latents.dtype
671
+ )
672
+
673
+ adapter_state = self.cond_encoder(control_image_feature)
674
+ for i, state in enumerate(adapter_state):
675
+ adapter_state[i] = state * control_conditioning_scale
676
+
677
+ # 8. Denoising loop
678
+ num_warmup_steps = max(
679
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
680
+ )
681
+
682
+ # 8.1 Apply denoising_end
683
+ if (
684
+ self.denoising_end is not None
685
+ and isinstance(self.denoising_end, float)
686
+ and self.denoising_end > 0
687
+ and self.denoising_end < 1
688
+ ):
689
+ discrete_timestep_cutoff = int(
690
+ round(
691
+ self.scheduler.config.num_train_timesteps
692
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
693
+ )
694
+ )
695
+ num_inference_steps = len(
696
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
697
+ )
698
+ timesteps = timesteps[:num_inference_steps]
699
+
700
+ # 9. Optionally get Guidance Scale Embedding
701
+ timestep_cond = None
702
+ if self.unet.config.time_cond_proj_dim is not None:
703
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
704
+ batch_size * num_images_per_prompt
705
+ )
706
+ timestep_cond = self.get_guidance_scale_embedding(
707
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
708
+ ).to(device=device, dtype=latents.dtype)
709
+
710
+ self._num_timesteps = len(timesteps)
711
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
712
+ for i, t in enumerate(timesteps):
713
+ if self.interrupt:
714
+ continue
715
+
716
+ # expand the latents if we are doing classifier free guidance
717
+ latent_model_input = (
718
+ torch.cat([latents] * 2)
719
+ if self.do_classifier_free_guidance
720
+ else latents
721
+ )
722
+
723
+ latent_model_input = self.scheduler.scale_model_input(
724
+ latent_model_input, t
725
+ )
726
+
727
+ added_cond_kwargs = {
728
+ "text_embeds": add_text_embeds,
729
+ "time_ids": add_time_ids,
730
+ }
731
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
732
+ added_cond_kwargs["image_embeds"] = image_embeds
733
+
734
+ if i < int(num_inference_steps * control_conditioning_factor):
735
+ down_intrablock_additional_residuals = [
736
+ state.clone() for state in adapter_state
737
+ ]
738
+ else:
739
+ down_intrablock_additional_residuals = None
740
+
741
+ # predict the noise residual
742
+ noise_pred = self.unet(
743
+ latent_model_input,
744
+ t,
745
+ encoder_hidden_states=prompt_embeds,
746
+ timestep_cond=timestep_cond,
747
+ cross_attention_kwargs=cross_attention_kwargs,
748
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals,
749
+ added_cond_kwargs=added_cond_kwargs,
750
+ return_dict=False,
751
+ )[0]
752
+
753
+ # perform guidance
754
+ if self.do_classifier_free_guidance:
755
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
756
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
757
+ noise_pred_text - noise_pred_uncond
758
+ )
759
+
760
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
761
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
762
+ noise_pred = rescale_noise_cfg(
763
+ noise_pred,
764
+ noise_pred_text,
765
+ guidance_rescale=self.guidance_rescale,
766
+ )
767
+
768
+ # compute the previous noisy sample x_t -> x_t-1
769
+ latents_dtype = latents.dtype
770
+ latents = self.scheduler.step(
771
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
772
+ )[0]
773
+ if latents.dtype != latents_dtype:
774
+ if torch.backends.mps.is_available():
775
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
776
+ latents = latents.to(latents_dtype)
777
+
778
+ if callback_on_step_end is not None:
779
+ callback_kwargs = {}
780
+ for k in callback_on_step_end_tensor_inputs:
781
+ callback_kwargs[k] = locals()[k]
782
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
783
+
784
+ latents = callback_outputs.pop("latents", latents)
785
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
786
+ negative_prompt_embeds = callback_outputs.pop(
787
+ "negative_prompt_embeds", negative_prompt_embeds
788
+ )
789
+ add_text_embeds = callback_outputs.pop(
790
+ "add_text_embeds", add_text_embeds
791
+ )
792
+ negative_pooled_prompt_embeds = callback_outputs.pop(
793
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
794
+ )
795
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
796
+ negative_add_time_ids = callback_outputs.pop(
797
+ "negative_add_time_ids", negative_add_time_ids
798
+ )
799
+
800
+ # call the callback, if provided
801
+ if i == len(timesteps) - 1 or (
802
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
803
+ ):
804
+ progress_bar.update()
805
+ if callback is not None and i % callback_steps == 0:
806
+ step_idx = i // getattr(self.scheduler, "order", 1)
807
+ callback(step_idx, t, latents)
808
+
809
+ if not output_type == "latent":
810
+ # make sure the VAE is in float32 mode, as it overflows in float16
811
+ needs_upcasting = (
812
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
813
+ )
814
+
815
+ if needs_upcasting:
816
+ self.upcast_vae()
817
+ latents = latents.to(
818
+ next(iter(self.vae.post_quant_conv.parameters())).dtype
819
+ )
820
+ elif latents.dtype != self.vae.dtype:
821
+ if torch.backends.mps.is_available():
822
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
823
+ self.vae = self.vae.to(latents.dtype)
824
+
825
+ # unscale/denormalize the latents
826
+ # denormalize with the mean and std if available and not None
827
+ has_latents_mean = (
828
+ hasattr(self.vae.config, "latents_mean")
829
+ and self.vae.config.latents_mean is not None
830
+ )
831
+ has_latents_std = (
832
+ hasattr(self.vae.config, "latents_std")
833
+ and self.vae.config.latents_std is not None
834
+ )
835
+ if has_latents_mean and has_latents_std:
836
+ latents_mean = (
837
+ torch.tensor(self.vae.config.latents_mean)
838
+ .view(1, 4, 1, 1)
839
+ .to(latents.device, latents.dtype)
840
+ )
841
+ latents_std = (
842
+ torch.tensor(self.vae.config.latents_std)
843
+ .view(1, 4, 1, 1)
844
+ .to(latents.device, latents.dtype)
845
+ )
846
+ latents = (
847
+ latents * latents_std / self.vae.config.scaling_factor
848
+ + latents_mean
849
+ )
850
+ else:
851
+ latents = latents / self.vae.config.scaling_factor
852
+
853
+ image = self.vae.decode(latents, return_dict=False)[0]
854
+
855
+ # cast back to fp16 if needed
856
+ if needs_upcasting:
857
+ self.vae.to(dtype=torch.float16)
858
+ else:
859
+ image = latents
860
+
861
+ if not output_type == "latent":
862
+ # apply watermark if available
863
+ if self.watermark is not None:
864
+ image = self.watermark.apply_watermark(image)
865
+
866
+ image = self.image_processor.postprocess(image, output_type=output_type)
867
+
868
+ # Offload all models
869
+ self.maybe_free_model_hooks()
870
+
871
+ if not return_dict:
872
+ return (image,)
873
+
874
+ return StableDiffusionXLPipelineOutput(images=image)
875
+
876
+ ### NEW: adapters ###
877
+ def _init_custom_adapter(
878
+ self,
879
+ # Multi-view adapter
880
+ num_views: int,
881
+ self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
882
+ # Condition encoder
883
+ cond_in_channels: int = 6,
884
+ # For training
885
+ copy_attn_weights: bool = True,
886
+ zero_init_module_keys: List[str] = [],
887
+ ):
888
+ # Condition encoder
889
+ self.cond_encoder = T2IAdapter(
890
+ in_channels=cond_in_channels,
891
+ channels=(320, 640, 1280, 1280),
892
+ num_res_blocks=2,
893
+ downscale_factor=16,
894
+ adapter_type="full_adapter_xl",
895
+ )
896
+
897
+ # set custom attn processor for multi-view attention and image cross-attention
898
+ self.unet: UNet2DConditionModel
899
+ set_unet_2d_condition_attn_processor(
900
+ self.unet,
901
+ set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
902
+ query_dim=hs,
903
+ inner_dim=hs,
904
+ num_views=num_views,
905
+ name=name,
906
+ use_mv=True,
907
+ use_ref=True,
908
+ ),
909
+ )
910
+
911
+ # copy decoupled attention weights from original unet
912
+ if copy_attn_weights:
913
+ state_dict = self.unet.state_dict()
914
+ for key in state_dict.keys():
915
+ if "_mv" in key:
916
+ compatible_key = key.replace("_mv", "").replace("processor.", "")
917
+ elif "_ref" in key:
918
+ compatible_key = key.replace("_ref", "").replace("processor.", "")
919
+ else:
920
+ compatible_key = key
921
+
922
+ is_zero_init_key = any([k in key for k in zero_init_module_keys])
923
+ if is_zero_init_key:
924
+ state_dict[key] = torch.zeros_like(state_dict[compatible_key])
925
+ else:
926
+ state_dict[key] = state_dict[compatible_key].clone()
927
+ self.unet.load_state_dict(state_dict)
928
+
929
+ def _load_custom_adapter(self, state_dict):
930
+ self.unet.load_state_dict(state_dict, strict=False)
931
+ self.cond_encoder.load_state_dict(state_dict, strict=False)
932
+
933
+ def _save_custom_adapter(
934
+ self,
935
+ include_keys: Optional[List[str]] = None,
936
+ exclude_keys: Optional[List[str]] = None,
937
+ ):
938
+ def include_fn(k):
939
+ is_included = False
940
+
941
+ if include_keys is not None:
942
+ is_included = is_included or any([key in k for key in include_keys])
943
+ if exclude_keys is not None:
944
+ is_included = is_included and not any(
945
+ [key in k for key in exclude_keys]
946
+ )
947
+
948
+ return is_included
949
+
950
+ state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
951
+ state_dict.update(self.cond_encoder.state_dict())
952
+
953
+ return state_dict
mvadapter/pipelines/pipeline_mvadapter_t2mv_sdxl.py ADDED
@@ -0,0 +1,792 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
19
+ from diffusers.models import AutoencoderKL, T2IAdapter, UNet2DConditionModel
20
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
21
+ StableDiffusionXLPipelineOutput,
22
+ )
23
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
24
+ StableDiffusionXLPipeline,
25
+ rescale_noise_cfg,
26
+ retrieve_timesteps,
27
+ )
28
+ from diffusers.schedulers import KarrasDiffusionSchedulers
29
+ from diffusers.utils import deprecate, logging
30
+ from transformers import (
31
+ CLIPImageProcessor,
32
+ CLIPTextModel,
33
+ CLIPTextModelWithProjection,
34
+ CLIPTokenizer,
35
+ CLIPVisionModelWithProjection,
36
+ )
37
+
38
+ from ..loaders import CustomAdapterMixin
39
+ from ..models.attention_processor import (
40
+ DecoupledMVRowSelfAttnProcessor2_0,
41
+ set_unet_2d_condition_attn_processor,
42
+ )
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+
47
+ class MVAdapterT2MVSDXLPipeline(StableDiffusionXLPipeline, CustomAdapterMixin):
48
+ def __init__(
49
+ self,
50
+ vae: AutoencoderKL,
51
+ text_encoder: CLIPTextModel,
52
+ text_encoder_2: CLIPTextModelWithProjection,
53
+ tokenizer: CLIPTokenizer,
54
+ tokenizer_2: CLIPTokenizer,
55
+ unet: UNet2DConditionModel,
56
+ scheduler: KarrasDiffusionSchedulers,
57
+ image_encoder: CLIPVisionModelWithProjection = None,
58
+ feature_extractor: CLIPImageProcessor = None,
59
+ force_zeros_for_empty_prompt: bool = True,
60
+ add_watermarker: Optional[bool] = None,
61
+ ):
62
+ super().__init__(
63
+ vae=vae,
64
+ text_encoder=text_encoder,
65
+ text_encoder_2=text_encoder_2,
66
+ tokenizer=tokenizer,
67
+ tokenizer_2=tokenizer_2,
68
+ unet=unet,
69
+ scheduler=scheduler,
70
+ image_encoder=image_encoder,
71
+ feature_extractor=feature_extractor,
72
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
73
+ add_watermarker=add_watermarker,
74
+ )
75
+
76
+ self.control_image_processor = VaeImageProcessor(
77
+ vae_scale_factor=self.vae_scale_factor,
78
+ do_convert_rgb=True,
79
+ do_normalize=False,
80
+ )
81
+
82
+ def prepare_control_image(
83
+ self,
84
+ image,
85
+ width,
86
+ height,
87
+ batch_size,
88
+ num_images_per_prompt,
89
+ device,
90
+ dtype,
91
+ do_classifier_free_guidance=False,
92
+ ):
93
+ assert hasattr(
94
+ self, "control_image_processor"
95
+ ), "control_image_processor is not initialized"
96
+
97
+ image = self.control_image_processor.preprocess(
98
+ image, height=height, width=width
99
+ ).to(dtype=torch.float32)
100
+ image_batch_size = image.shape[0]
101
+
102
+ if image_batch_size == 1:
103
+ repeat_by = batch_size
104
+ else:
105
+ # image batch size is the same as prompt batch size
106
+ repeat_by = num_images_per_prompt # always 1 for control image
107
+
108
+ image = image.repeat_interleave(repeat_by, dim=0)
109
+
110
+ image = image.to(device=device, dtype=dtype)
111
+
112
+ if do_classifier_free_guidance:
113
+ image = torch.cat([image] * 2)
114
+
115
+ return image
116
+
117
+ @torch.no_grad()
118
+ def __call__(
119
+ self,
120
+ prompt: Union[str, List[str]] = None,
121
+ prompt_2: Optional[Union[str, List[str]]] = None,
122
+ height: Optional[int] = None,
123
+ width: Optional[int] = None,
124
+ num_inference_steps: int = 50,
125
+ timesteps: List[int] = None,
126
+ denoising_end: Optional[float] = None,
127
+ guidance_scale: float = 5.0,
128
+ negative_prompt: Optional[Union[str, List[str]]] = None,
129
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
130
+ num_images_per_prompt: Optional[int] = 1,
131
+ eta: float = 0.0,
132
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
133
+ latents: Optional[torch.FloatTensor] = None,
134
+ prompt_embeds: Optional[torch.FloatTensor] = None,
135
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
136
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
137
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
138
+ ip_adapter_image: Optional[PipelineImageInput] = None,
139
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
140
+ output_type: Optional[str] = "pil",
141
+ return_dict: bool = True,
142
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
143
+ guidance_rescale: float = 0.0,
144
+ original_size: Optional[Tuple[int, int]] = None,
145
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
146
+ target_size: Optional[Tuple[int, int]] = None,
147
+ negative_original_size: Optional[Tuple[int, int]] = None,
148
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
149
+ negative_target_size: Optional[Tuple[int, int]] = None,
150
+ clip_skip: Optional[int] = None,
151
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
152
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
153
+ # NEW
154
+ mv_scale: float = 1.0,
155
+ # Camera or geometry condition
156
+ control_image: Optional[PipelineImageInput] = None,
157
+ control_conditioning_scale: Optional[float] = 1.0,
158
+ control_conditioning_factor: float = 1.0,
159
+ # Optional. controlnet
160
+ controlnet_image: Optional[PipelineImageInput] = None,
161
+ controlnet_conditioning_scale: Optional[float] = 1.0,
162
+ **kwargs,
163
+ ):
164
+ r"""
165
+ Function invoked when calling the pipeline for generation.
166
+
167
+ Args:
168
+ prompt (`str` or `List[str]`, *optional*):
169
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
170
+ instead.
171
+ prompt_2 (`str` or `List[str]`, *optional*):
172
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
173
+ used in both text-encoders
174
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
175
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
176
+ Anything below 512 pixels won't work well for
177
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
178
+ and checkpoints that are not specifically fine-tuned on low resolutions.
179
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
180
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
181
+ Anything below 512 pixels won't work well for
182
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
183
+ and checkpoints that are not specifically fine-tuned on low resolutions.
184
+ num_inference_steps (`int`, *optional*, defaults to 50):
185
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
186
+ expense of slower inference.
187
+ timesteps (`List[int]`, *optional*):
188
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
189
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
190
+ passed will be used. Must be in descending order.
191
+ denoising_end (`float`, *optional*):
192
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
193
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
194
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
195
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
196
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
197
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
198
+ guidance_scale (`float`, *optional*, defaults to 5.0):
199
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
200
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
201
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
202
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
203
+ usually at the expense of lower image quality.
204
+ negative_prompt (`str` or `List[str]`, *optional*):
205
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
206
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
207
+ less than `1`).
208
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
209
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
210
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
211
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
212
+ The number of images to generate per prompt.
213
+ eta (`float`, *optional*, defaults to 0.0):
214
+ Corresponds to parameter eta (ฮท) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
215
+ [`schedulers.DDIMScheduler`], will be ignored for others.
216
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
217
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
218
+ to make generation deterministic.
219
+ latents (`torch.FloatTensor`, *optional*):
220
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
221
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
222
+ tensor will ge generated by sampling using the supplied random `generator`.
223
+ prompt_embeds (`torch.FloatTensor`, *optional*):
224
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
225
+ provided, text embeddings will be generated from `prompt` input argument.
226
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
227
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
228
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
229
+ argument.
230
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
231
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
232
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
233
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
234
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
235
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
236
+ input argument.
237
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
238
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
239
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
240
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
241
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
242
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
243
+ output_type (`str`, *optional*, defaults to `"pil"`):
244
+ The output format of the generate image. Choose between
245
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
246
+ return_dict (`bool`, *optional*, defaults to `True`):
247
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
248
+ of a plain tuple.
249
+ cross_attention_kwargs (`dict`, *optional*):
250
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
251
+ `self.processor` in
252
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
253
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
254
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
255
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `ฯ†` in equation 16. of
256
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
257
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
258
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
259
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
260
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
261
+ explained in section 2.2 of
262
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
263
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
264
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
265
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
266
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
267
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
268
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
269
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
270
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
271
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
272
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
273
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
274
+ micro-conditioning as explained in section 2.2 of
275
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
276
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
277
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
278
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
279
+ micro-conditioning as explained in section 2.2 of
280
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
281
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
282
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
283
+ To negatively condition the generation process based on a target image resolution. It should be as same
284
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
285
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
286
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
287
+ callback_on_step_end (`Callable`, *optional*):
288
+ A function that calls at the end of each denoising steps during the inference. The function is called
289
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
290
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
291
+ `callback_on_step_end_tensor_inputs`.
292
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
293
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
294
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
295
+ `._callback_tensor_inputs` attribute of your pipeline class.
296
+
297
+ Examples:
298
+
299
+ Returns:
300
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
301
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
302
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
303
+ """
304
+
305
+ callback = kwargs.pop("callback", None)
306
+ callback_steps = kwargs.pop("callback_steps", None)
307
+
308
+ if callback is not None:
309
+ deprecate(
310
+ "callback",
311
+ "1.0.0",
312
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
313
+ )
314
+ if callback_steps is not None:
315
+ deprecate(
316
+ "callback_steps",
317
+ "1.0.0",
318
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
319
+ )
320
+
321
+ # 0. Default height and width to unet
322
+ height = height or self.default_sample_size * self.vae_scale_factor
323
+ width = width or self.default_sample_size * self.vae_scale_factor
324
+
325
+ original_size = original_size or (height, width)
326
+ target_size = target_size or (height, width)
327
+
328
+ # 1. Check inputs. Raise error if not correct
329
+ self.check_inputs(
330
+ prompt,
331
+ prompt_2,
332
+ height,
333
+ width,
334
+ callback_steps,
335
+ negative_prompt,
336
+ negative_prompt_2,
337
+ prompt_embeds,
338
+ negative_prompt_embeds,
339
+ pooled_prompt_embeds,
340
+ negative_pooled_prompt_embeds,
341
+ ip_adapter_image,
342
+ ip_adapter_image_embeds,
343
+ callback_on_step_end_tensor_inputs,
344
+ )
345
+
346
+ self._guidance_scale = guidance_scale
347
+ self._guidance_rescale = guidance_rescale
348
+ self._clip_skip = clip_skip
349
+ self._cross_attention_kwargs = cross_attention_kwargs
350
+ self._denoising_end = denoising_end
351
+ self._interrupt = False
352
+
353
+ # 2. Define call parameters
354
+ if prompt is not None and isinstance(prompt, str):
355
+ batch_size = 1
356
+ elif prompt is not None and isinstance(prompt, list):
357
+ batch_size = len(prompt)
358
+ else:
359
+ batch_size = prompt_embeds.shape[0]
360
+
361
+ device = self._execution_device
362
+
363
+ # 3. Encode input prompt
364
+ lora_scale = (
365
+ self.cross_attention_kwargs.get("scale", None)
366
+ if self.cross_attention_kwargs is not None
367
+ else None
368
+ )
369
+
370
+ (
371
+ prompt_embeds,
372
+ negative_prompt_embeds,
373
+ pooled_prompt_embeds,
374
+ negative_pooled_prompt_embeds,
375
+ ) = self.encode_prompt(
376
+ prompt=prompt,
377
+ prompt_2=prompt_2,
378
+ device=device,
379
+ num_images_per_prompt=num_images_per_prompt,
380
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
381
+ negative_prompt=negative_prompt,
382
+ negative_prompt_2=negative_prompt_2,
383
+ prompt_embeds=prompt_embeds,
384
+ negative_prompt_embeds=negative_prompt_embeds,
385
+ pooled_prompt_embeds=pooled_prompt_embeds,
386
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
387
+ lora_scale=lora_scale,
388
+ clip_skip=self.clip_skip,
389
+ )
390
+
391
+ # 4. Prepare timesteps
392
+ timesteps, num_inference_steps = retrieve_timesteps(
393
+ self.scheduler, num_inference_steps, device, timesteps
394
+ )
395
+
396
+ # 5. Prepare latent variables
397
+ num_channels_latents = self.unet.config.in_channels
398
+ latents = self.prepare_latents(
399
+ batch_size * num_images_per_prompt,
400
+ num_channels_latents,
401
+ height,
402
+ width,
403
+ prompt_embeds.dtype,
404
+ device,
405
+ generator,
406
+ latents,
407
+ )
408
+
409
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
410
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
411
+
412
+ # 7. Prepare added time ids & embeddings
413
+ add_text_embeds = pooled_prompt_embeds
414
+ if self.text_encoder_2 is None:
415
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
416
+ else:
417
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
418
+
419
+ add_time_ids = self._get_add_time_ids(
420
+ original_size,
421
+ crops_coords_top_left,
422
+ target_size,
423
+ dtype=prompt_embeds.dtype,
424
+ text_encoder_projection_dim=text_encoder_projection_dim,
425
+ )
426
+ if negative_original_size is not None and negative_target_size is not None:
427
+ negative_add_time_ids = self._get_add_time_ids(
428
+ negative_original_size,
429
+ negative_crops_coords_top_left,
430
+ negative_target_size,
431
+ dtype=prompt_embeds.dtype,
432
+ text_encoder_projection_dim=text_encoder_projection_dim,
433
+ )
434
+ else:
435
+ negative_add_time_ids = add_time_ids
436
+
437
+ if self.do_classifier_free_guidance:
438
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
439
+ add_text_embeds = torch.cat(
440
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
441
+ )
442
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
443
+
444
+ prompt_embeds = prompt_embeds.to(device)
445
+ add_text_embeds = add_text_embeds.to(device)
446
+ add_time_ids = add_time_ids.to(device).repeat(
447
+ batch_size * num_images_per_prompt, 1
448
+ )
449
+
450
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
451
+ image_embeds = self.prepare_ip_adapter_image_embeds(
452
+ ip_adapter_image,
453
+ ip_adapter_image_embeds,
454
+ device,
455
+ batch_size * num_images_per_prompt,
456
+ self.do_classifier_free_guidance,
457
+ )
458
+
459
+ # Preprocess control image
460
+ control_image_feature = self.prepare_control_image(
461
+ image=control_image,
462
+ width=width,
463
+ height=height,
464
+ batch_size=batch_size * num_images_per_prompt,
465
+ num_images_per_prompt=1, # NOTE: always 1 for control images
466
+ device=device,
467
+ dtype=latents.dtype,
468
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
469
+ )
470
+ control_image_feature = control_image_feature.to(
471
+ device=device, dtype=latents.dtype
472
+ )
473
+
474
+ adapter_state = self.cond_encoder(control_image_feature)
475
+ for i, state in enumerate(adapter_state):
476
+ adapter_state[i] = state * control_conditioning_scale
477
+
478
+ # Preprocess controlnet image if provided
479
+ do_controlnet = controlnet_image is not None and hasattr(self, "controlnet")
480
+ if do_controlnet:
481
+ controlnet_image = self.prepare_control_image(
482
+ image=controlnet_image,
483
+ width=width,
484
+ height=height,
485
+ batch_size=batch_size * num_images_per_prompt,
486
+ num_images_per_prompt=1, # NOTE: always 1 for control images
487
+ device=device,
488
+ dtype=latents.dtype,
489
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
490
+ )
491
+ controlnet_image = controlnet_image.to(device=device, dtype=latents.dtype)
492
+
493
+ # 8. Denoising loop
494
+ num_warmup_steps = max(
495
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
496
+ )
497
+
498
+ # 8.1 Apply denoising_end
499
+ if (
500
+ self.denoising_end is not None
501
+ and isinstance(self.denoising_end, float)
502
+ and self.denoising_end > 0
503
+ and self.denoising_end < 1
504
+ ):
505
+ discrete_timestep_cutoff = int(
506
+ round(
507
+ self.scheduler.config.num_train_timesteps
508
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
509
+ )
510
+ )
511
+ num_inference_steps = len(
512
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
513
+ )
514
+ timesteps = timesteps[:num_inference_steps]
515
+
516
+ # 9. Optionally get Guidance Scale Embedding
517
+ timestep_cond = None
518
+ if self.unet.config.time_cond_proj_dim is not None:
519
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
520
+ batch_size * num_images_per_prompt
521
+ )
522
+ timestep_cond = self.get_guidance_scale_embedding(
523
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
524
+ ).to(device=device, dtype=latents.dtype)
525
+
526
+ self._num_timesteps = len(timesteps)
527
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
528
+ for i, t in enumerate(timesteps):
529
+ if self.interrupt:
530
+ continue
531
+
532
+ # expand the latents if we are doing classifier free guidance
533
+ latent_model_input = (
534
+ torch.cat([latents] * 2)
535
+ if self.do_classifier_free_guidance
536
+ else latents
537
+ )
538
+
539
+ latent_model_input = self.scheduler.scale_model_input(
540
+ latent_model_input, t
541
+ )
542
+
543
+ added_cond_kwargs = {
544
+ "text_embeds": add_text_embeds,
545
+ "time_ids": add_time_ids,
546
+ }
547
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
548
+ added_cond_kwargs["image_embeds"] = image_embeds
549
+
550
+ if i < int(num_inference_steps * control_conditioning_factor):
551
+ down_intrablock_additional_residuals = [
552
+ state.clone() for state in adapter_state
553
+ ]
554
+ else:
555
+ down_intrablock_additional_residuals = None
556
+
557
+ unet_add_kwargs = {}
558
+
559
+ # Do controlnet if provided
560
+ if do_controlnet:
561
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
562
+ latent_model_input,
563
+ t,
564
+ encoder_hidden_states=prompt_embeds,
565
+ controlnet_cond=controlnet_image,
566
+ conditioning_scale=controlnet_conditioning_scale,
567
+ guess_mode=False,
568
+ added_cond_kwargs=added_cond_kwargs,
569
+ return_dict=False,
570
+ )
571
+ unet_add_kwargs.update(
572
+ {
573
+ "down_block_additional_residuals": down_block_res_samples,
574
+ "mid_block_additional_residual": mid_block_res_sample,
575
+ }
576
+ )
577
+
578
+ # predict the noise residual
579
+ noise_pred = self.unet(
580
+ latent_model_input,
581
+ t,
582
+ encoder_hidden_states=prompt_embeds,
583
+ timestep_cond=timestep_cond,
584
+ cross_attention_kwargs={
585
+ "mv_scale": mv_scale,
586
+ **(self.cross_attention_kwargs or {}),
587
+ },
588
+ down_intrablock_additional_residuals=down_intrablock_additional_residuals,
589
+ added_cond_kwargs=added_cond_kwargs,
590
+ return_dict=False,
591
+ **unet_add_kwargs,
592
+ )[0]
593
+
594
+ # perform guidance
595
+ if self.do_classifier_free_guidance:
596
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
597
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
598
+ noise_pred_text - noise_pred_uncond
599
+ )
600
+
601
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
602
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
603
+ noise_pred = rescale_noise_cfg(
604
+ noise_pred,
605
+ noise_pred_text,
606
+ guidance_rescale=self.guidance_rescale,
607
+ )
608
+
609
+ # compute the previous noisy sample x_t -> x_t-1
610
+ latents_dtype = latents.dtype
611
+ latents = self.scheduler.step(
612
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
613
+ )[0]
614
+ if latents.dtype != latents_dtype:
615
+ if torch.backends.mps.is_available():
616
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
617
+ latents = latents.to(latents_dtype)
618
+
619
+ if callback_on_step_end is not None:
620
+ callback_kwargs = {}
621
+ for k in callback_on_step_end_tensor_inputs:
622
+ callback_kwargs[k] = locals()[k]
623
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
624
+
625
+ latents = callback_outputs.pop("latents", latents)
626
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
627
+ negative_prompt_embeds = callback_outputs.pop(
628
+ "negative_prompt_embeds", negative_prompt_embeds
629
+ )
630
+ add_text_embeds = callback_outputs.pop(
631
+ "add_text_embeds", add_text_embeds
632
+ )
633
+ negative_pooled_prompt_embeds = callback_outputs.pop(
634
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
635
+ )
636
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
637
+ negative_add_time_ids = callback_outputs.pop(
638
+ "negative_add_time_ids", negative_add_time_ids
639
+ )
640
+
641
+ # call the callback, if provided
642
+ if i == len(timesteps) - 1 or (
643
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
644
+ ):
645
+ progress_bar.update()
646
+ if callback is not None and i % callback_steps == 0:
647
+ step_idx = i // getattr(self.scheduler, "order", 1)
648
+ callback(step_idx, t, latents)
649
+
650
+ if not output_type == "latent":
651
+ # make sure the VAE is in float32 mode, as it overflows in float16
652
+ needs_upcasting = (
653
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
654
+ )
655
+
656
+ if needs_upcasting:
657
+ self.upcast_vae()
658
+ latents = latents.to(
659
+ next(iter(self.vae.post_quant_conv.parameters())).dtype
660
+ )
661
+ elif latents.dtype != self.vae.dtype:
662
+ if torch.backends.mps.is_available():
663
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
664
+ self.vae = self.vae.to(latents.dtype)
665
+
666
+ # unscale/denormalize the latents
667
+ # denormalize with the mean and std if available and not None
668
+ has_latents_mean = (
669
+ hasattr(self.vae.config, "latents_mean")
670
+ and self.vae.config.latents_mean is not None
671
+ )
672
+ has_latents_std = (
673
+ hasattr(self.vae.config, "latents_std")
674
+ and self.vae.config.latents_std is not None
675
+ )
676
+ if has_latents_mean and has_latents_std:
677
+ latents_mean = (
678
+ torch.tensor(self.vae.config.latents_mean)
679
+ .view(1, 4, 1, 1)
680
+ .to(latents.device, latents.dtype)
681
+ )
682
+ latents_std = (
683
+ torch.tensor(self.vae.config.latents_std)
684
+ .view(1, 4, 1, 1)
685
+ .to(latents.device, latents.dtype)
686
+ )
687
+ latents = (
688
+ latents * latents_std / self.vae.config.scaling_factor
689
+ + latents_mean
690
+ )
691
+ else:
692
+ latents = latents / self.vae.config.scaling_factor
693
+
694
+ image = self.vae.decode(latents, return_dict=False)[0]
695
+
696
+ # cast back to fp16 if needed
697
+ if needs_upcasting:
698
+ self.vae.to(dtype=torch.float16)
699
+ else:
700
+ image = latents
701
+
702
+ if not output_type == "latent":
703
+ # apply watermark if available
704
+ if self.watermark is not None:
705
+ image = self.watermark.apply_watermark(image)
706
+
707
+ image = self.image_processor.postprocess(image, output_type=output_type)
708
+
709
+ # Offload all models
710
+ self.maybe_free_model_hooks()
711
+
712
+ if not return_dict:
713
+ return (image,)
714
+
715
+ return StableDiffusionXLPipelineOutput(images=image)
716
+
717
+ ### NEW: adapters ###
718
+ def _init_custom_adapter(
719
+ self,
720
+ # Multi-view adapter
721
+ num_views: int,
722
+ self_attn_processor: Any = DecoupledMVRowSelfAttnProcessor2_0,
723
+ # Condition encoder
724
+ cond_in_channels: int = 6,
725
+ # For training
726
+ copy_attn_weights: bool = True,
727
+ zero_init_module_keys: List[str] = [],
728
+ ):
729
+ # Condition encoder
730
+ self.cond_encoder = T2IAdapter(
731
+ in_channels=cond_in_channels,
732
+ channels=(320, 640, 1280, 1280),
733
+ num_res_blocks=2,
734
+ downscale_factor=16,
735
+ adapter_type="full_adapter_xl",
736
+ )
737
+
738
+ # set custom attn processor for multi-view attention
739
+ self.unet: UNet2DConditionModel
740
+ set_unet_2d_condition_attn_processor(
741
+ self.unet,
742
+ set_self_attn_proc_func=lambda name, hs, cad, ap: self_attn_processor(
743
+ query_dim=hs,
744
+ inner_dim=hs,
745
+ num_views=num_views,
746
+ name=name,
747
+ use_mv=True,
748
+ use_ref=False,
749
+ ),
750
+ )
751
+
752
+ # copy decoupled attention weights from original unet
753
+ if copy_attn_weights:
754
+ state_dict = self.unet.state_dict()
755
+ for key in state_dict.keys():
756
+ if "_mv" in key:
757
+ compatible_key = key.replace("_mv", "").replace("processor.", "")
758
+ else:
759
+ compatible_key = key
760
+
761
+ is_zero_init_key = any([k in key for k in zero_init_module_keys])
762
+ if is_zero_init_key:
763
+ state_dict[key] = torch.zeros_like(state_dict[compatible_key])
764
+ else:
765
+ state_dict[key] = state_dict[compatible_key].clone()
766
+ self.unet.load_state_dict(state_dict)
767
+
768
+ def _load_custom_adapter(self, state_dict):
769
+ self.unet.load_state_dict(state_dict, strict=False)
770
+ self.cond_encoder.load_state_dict(state_dict, strict=False)
771
+
772
+ def _save_custom_adapter(
773
+ self,
774
+ include_keys: Optional[List[str]] = None,
775
+ exclude_keys: Optional[List[str]] = None,
776
+ ):
777
+ def include_fn(k):
778
+ is_included = False
779
+
780
+ if include_keys is not None:
781
+ is_included = is_included or any([key in k for key in include_keys])
782
+ if exclude_keys is not None:
783
+ is_included = is_included and not any(
784
+ [key in k for key in exclude_keys]
785
+ )
786
+
787
+ return is_included
788
+
789
+ state_dict = {k: v for k, v in self.unet.state_dict().items() if include_fn(k)}
790
+ state_dict.update(self.cond_encoder.state_dict())
791
+
792
+ return state_dict
mvadapter/schedulers/scheduler_utils.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device=None):
5
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
6
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
7
+ timesteps = timesteps.to(device)
8
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
9
+ sigma = sigmas[step_indices].flatten()
10
+ while len(sigma.shape) < n_dim:
11
+ sigma = sigma.unsqueeze(-1)
12
+ return sigma
13
+
14
+
15
+ def SNR_to_betas(snr):
16
+ """
17
+ Converts SNR to betas
18
+ """
19
+ # alphas_cumprod = pass
20
+ # snr = (alpha / ) ** 2
21
+ # alpha_t^2 / (1 - alpha_t^2) = snr
22
+ alpha_t = (snr / (1 + snr)) ** 0.5
23
+ alphas_cumprod = alpha_t**2
24
+ alphas = alphas_cumprod / torch.cat(
25
+ [torch.ones(1, device=snr.device), alphas_cumprod[:-1]]
26
+ )
27
+ betas = 1 - alphas
28
+ return betas
29
+
30
+
31
+ def compute_snr(timesteps, noise_scheduler):
32
+ """
33
+ Computes SNR as per Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5
34
+ """
35
+ alphas_cumprod = noise_scheduler.alphas_cumprod
36
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
37
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
38
+
39
+ # Expand the tensors.
40
+ # Adapted from Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5
41
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
42
+ timesteps
43
+ ].float()
44
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
45
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
46
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
47
+
48
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
49
+ device=timesteps.device
50
+ )[timesteps].float()
51
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
52
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
53
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
54
+
55
+ # Compute SNR.
56
+ snr = (alpha / sigma) ** 2
57
+ return snr
58
+
59
+
60
+ def compute_alpha(timesteps, noise_scheduler):
61
+ alphas_cumprod = noise_scheduler.alphas_cumprod
62
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
63
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
64
+ timesteps
65
+ ].float()
66
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
67
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
68
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
69
+
70
+ return alpha
mvadapter/schedulers/scheduling_shift_snr.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+ from .scheduler_utils import SNR_to_betas, compute_snr
6
+
7
+
8
+ class ShiftSNRScheduler:
9
+ def __init__(
10
+ self,
11
+ noise_scheduler: Any,
12
+ timesteps: Any,
13
+ shift_scale: float,
14
+ scheduler_class: Any,
15
+ ):
16
+ self.noise_scheduler = noise_scheduler
17
+ self.timesteps = timesteps
18
+ self.shift_scale = shift_scale
19
+ self.scheduler_class = scheduler_class
20
+
21
+ def _get_shift_scheduler(self):
22
+ """
23
+ Prepare scheduler for shifted betas.
24
+
25
+ :return: A scheduler object configured with shifted betas
26
+ """
27
+ snr = compute_snr(self.timesteps, self.noise_scheduler)
28
+ shifted_betas = SNR_to_betas(snr / self.shift_scale)
29
+
30
+ return self.scheduler_class.from_config(
31
+ self.noise_scheduler.config, trained_betas=shifted_betas.numpy()
32
+ )
33
+
34
+ def _get_interpolated_shift_scheduler(self):
35
+ """
36
+ Prepare scheduler for shifted betas and interpolate with the original betas in log space.
37
+
38
+ :return: A scheduler object configured with interpolated shifted betas
39
+ """
40
+ snr = compute_snr(self.timesteps, self.noise_scheduler)
41
+ shifted_snr = snr / self.shift_scale
42
+
43
+ weighting = self.timesteps.float() / (
44
+ self.noise_scheduler.config.num_train_timesteps - 1
45
+ )
46
+ interpolated_snr = torch.exp(
47
+ torch.log(snr) * (1 - weighting) + torch.log(shifted_snr) * weighting
48
+ )
49
+
50
+ shifted_betas = SNR_to_betas(interpolated_snr)
51
+
52
+ return self.scheduler_class.from_config(
53
+ self.noise_scheduler.config, trained_betas=shifted_betas.numpy()
54
+ )
55
+
56
+ @classmethod
57
+ def from_scheduler(
58
+ cls,
59
+ noise_scheduler: Any,
60
+ shift_mode: str = "default",
61
+ timesteps: Any = None,
62
+ shift_scale: float = 1.0,
63
+ scheduler_class: Any = None,
64
+ ):
65
+ # Check input
66
+ if timesteps is None:
67
+ timesteps = torch.arange(0, noise_scheduler.config.num_train_timesteps)
68
+ if scheduler_class is None:
69
+ scheduler_class = noise_scheduler.__class__
70
+
71
+ # Create scheduler
72
+ shift_scheduler = cls(
73
+ noise_scheduler=noise_scheduler,
74
+ timesteps=timesteps,
75
+ shift_scale=shift_scale,
76
+ scheduler_class=scheduler_class,
77
+ )
78
+
79
+ if shift_mode == "default":
80
+ return shift_scheduler._get_shift_scheduler()
81
+ elif shift_mode == "interpolated":
82
+ return shift_scheduler._get_interpolated_shift_scheduler()
83
+ else:
84
+ raise ValueError(f"Unknown shift_mode: {shift_mode}")
85
+
86
+
87
+ if __name__ == "__main__":
88
+ """
89
+ Compare the alpha values for different noise schedulers.
90
+ """
91
+ import matplotlib.pyplot as plt
92
+ from diffusers import DDPMScheduler
93
+
94
+ from .scheduler_utils import compute_alpha
95
+
96
+ # Base
97
+ timesteps = torch.arange(0, 1000)
98
+ noise_scheduler_base = DDPMScheduler.from_pretrained(
99
+ "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
100
+ )
101
+ alpha = compute_alpha(timesteps, noise_scheduler_base)
102
+ plt.plot(timesteps.numpy(), alpha.numpy(), label="Base")
103
+
104
+ # Kolors
105
+ num_train_timesteps_ = 1100
106
+ timesteps_ = torch.arange(0, num_train_timesteps_)
107
+ noise_kwargs = {"beta_end": 0.014, "num_train_timesteps": num_train_timesteps_}
108
+ noise_scheduler_kolors = DDPMScheduler.from_config(
109
+ noise_scheduler_base.config, **noise_kwargs
110
+ )
111
+ alpha = compute_alpha(timesteps_, noise_scheduler_kolors)
112
+ plt.plot(timesteps_.numpy(), alpha.numpy(), label="Kolors")
113
+
114
+ # Shift betas
115
+ shift_scale = 8.0
116
+ noise_scheduler_shift = ShiftSNRScheduler.from_scheduler(
117
+ noise_scheduler_base, shift_mode="default", shift_scale=shift_scale
118
+ )
119
+ alpha = compute_alpha(timesteps, noise_scheduler_shift)
120
+ plt.plot(timesteps.numpy(), alpha.numpy(), label="Shift Noise (scale 8.0)")
121
+
122
+ # Shift betas (interpolated)
123
+ noise_scheduler_inter = ShiftSNRScheduler.from_scheduler(
124
+ noise_scheduler_base, shift_mode="interpolated", shift_scale=shift_scale
125
+ )
126
+ alpha = compute_alpha(timesteps, noise_scheduler_inter)
127
+ plt.plot(timesteps.numpy(), alpha.numpy(), label="Interpolated (scale 8.0)")
128
+
129
+ # ZeroSNR
130
+ noise_scheduler = DDPMScheduler.from_config(
131
+ noise_scheduler_base.config, rescale_betas_zero_snr=True
132
+ )
133
+ alpha = compute_alpha(timesteps, noise_scheduler)
134
+ plt.plot(timesteps.numpy(), alpha.numpy(), label="ZeroSNR")
135
+
136
+ plt.legend()
137
+ plt.grid()
138
+ plt.savefig("check_alpha.png")
mvadapter/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .camera import get_camera, get_orthogonal_camera
2
+ from .geometry import get_plucker_embeds_from_cameras_ortho
3
+ from .saving import make_image_grid, tensor_to_image
mvadapter/utils/camera.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ # import trimesh
9
+
10
+
11
+ from PIL import Image
12
+ from torch import BoolTensor, FloatTensor
13
+
14
+ LIST_TYPE = Union[list, np.ndarray, torch.Tensor]
15
+
16
+
17
+ def list_to_pt(
18
+ x: LIST_TYPE, dtype: Optional[torch.dtype] = None, device: Optional[str] = None
19
+ ) -> torch.Tensor:
20
+ if isinstance(x, list) or isinstance(x, np.ndarray):
21
+ return torch.tensor(x, dtype=dtype, device=device)
22
+ return x.to(dtype=dtype)
23
+
24
+
25
+ def get_c2w(
26
+ elevation_deg: LIST_TYPE,
27
+ distance: LIST_TYPE,
28
+ azimuth_deg: Optional[LIST_TYPE],
29
+ num_views: Optional[int] = 1,
30
+ device: Optional[str] = None,
31
+ ) -> torch.FloatTensor:
32
+ if azimuth_deg is None:
33
+ assert (
34
+ num_views is not None
35
+ ), "num_views must be provided if azimuth_deg is None."
36
+ azimuth_deg = torch.linspace(
37
+ 0, 360, num_views + 1, dtype=torch.float32, device=device
38
+ )[:-1]
39
+ else:
40
+ num_views = len(azimuth_deg)
41
+ azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device)
42
+ elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device)
43
+ camera_distances = list_to_pt(distance, dtype=torch.float32, device=device)
44
+ elevation = elevation_deg * math.pi / 180
45
+ azimuth = azimuth_deg * math.pi / 180
46
+ camera_positions = torch.stack(
47
+ [
48
+ camera_distances * torch.cos(elevation) * torch.cos(azimuth),
49
+ camera_distances * torch.cos(elevation) * torch.sin(azimuth),
50
+ camera_distances * torch.sin(elevation),
51
+ ],
52
+ dim=-1,
53
+ )
54
+ center = torch.zeros_like(camera_positions)
55
+ up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[None, :].repeat(
56
+ num_views, 1
57
+ )
58
+ lookat = F.normalize(center - camera_positions, dim=-1)
59
+ right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
60
+ up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
61
+ c2w3x4 = torch.cat(
62
+ [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
63
+ dim=-1,
64
+ )
65
+ c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
66
+ c2w[:, 3, 3] = 1.0
67
+ return c2w
68
+
69
+
70
+ def get_projection_matrix(
71
+ fovy_deg: LIST_TYPE,
72
+ aspect_wh: float = 1.0,
73
+ near: float = 0.1,
74
+ far: float = 100.0,
75
+ device: Optional[str] = None,
76
+ ) -> torch.FloatTensor:
77
+ fovy_deg = list_to_pt(fovy_deg, dtype=torch.float32, device=device)
78
+ batch_size = fovy_deg.shape[0]
79
+ fovy = fovy_deg * math.pi / 180
80
+ tan_half_fovy = torch.tan(fovy / 2)
81
+ projection_matrix = torch.zeros(
82
+ batch_size, 4, 4, dtype=torch.float32, device=device
83
+ )
84
+ projection_matrix[:, 0, 0] = 1 / (aspect_wh * tan_half_fovy)
85
+ projection_matrix[:, 1, 1] = -1 / tan_half_fovy
86
+ projection_matrix[:, 2, 2] = -(far + near) / (far - near)
87
+ projection_matrix[:, 2, 3] = -2 * far * near / (far - near)
88
+ projection_matrix[:, 3, 2] = -1
89
+ return projection_matrix
90
+
91
+
92
+ def get_orthogonal_projection_matrix(
93
+ batch_size: int,
94
+ left: float,
95
+ right: float,
96
+ bottom: float,
97
+ top: float,
98
+ near: float = 0.1,
99
+ far: float = 100.0,
100
+ device: Optional[str] = None,
101
+ ) -> torch.FloatTensor:
102
+ projection_matrix = torch.zeros(
103
+ batch_size, 4, 4, dtype=torch.float32, device=device
104
+ )
105
+ projection_matrix[:, 0, 0] = 2 / (right - left)
106
+ projection_matrix[:, 1, 1] = -2 / (top - bottom)
107
+ projection_matrix[:, 2, 2] = -2 / (far - near)
108
+ projection_matrix[:, 0, 3] = -(right + left) / (right - left)
109
+ projection_matrix[:, 1, 3] = -(top + bottom) / (top - bottom)
110
+ projection_matrix[:, 2, 3] = -(far + near) / (far - near)
111
+ projection_matrix[:, 3, 3] = 1
112
+ return projection_matrix
113
+
114
+
115
+ @dataclass
116
+ class Camera:
117
+ c2w: Optional[torch.FloatTensor]
118
+ w2c: torch.FloatTensor
119
+ proj_mtx: torch.FloatTensor
120
+ mvp_mtx: torch.FloatTensor
121
+ cam_pos: Optional[torch.FloatTensor]
122
+
123
+ def __getitem__(self, index):
124
+ if isinstance(index, int):
125
+ sl = slice(index, index + 1)
126
+ elif isinstance(index, slice):
127
+ sl = index
128
+ else:
129
+ raise NotImplementedError
130
+
131
+ return Camera(
132
+ c2w=self.c2w[sl] if self.c2w is not None else None,
133
+ w2c=self.w2c[sl],
134
+ proj_mtx=self.proj_mtx[sl],
135
+ mvp_mtx=self.mvp_mtx[sl],
136
+ cam_pos=self.cam_pos[sl] if self.cam_pos is not None else None,
137
+ )
138
+
139
+ def to(self, device: Optional[str] = None):
140
+ if self.c2w is not None:
141
+ self.c2w = self.c2w.to(device)
142
+ self.w2c = self.w2c.to(device)
143
+ self.proj_mtx = self.proj_mtx.to(device)
144
+ self.mvp_mtx = self.mvp_mtx.to(device)
145
+ if self.cam_pos is not None:
146
+ self.cam_pos = self.cam_pos.to(device)
147
+
148
+ def __len__(self):
149
+ return self.c2w.shape[0]
150
+
151
+
152
+ def get_camera(
153
+ elevation_deg: Optional[LIST_TYPE] = None,
154
+ distance: Optional[LIST_TYPE] = None,
155
+ fovy_deg: Optional[LIST_TYPE] = None,
156
+ azimuth_deg: Optional[LIST_TYPE] = None,
157
+ num_views: Optional[int] = 1,
158
+ c2w: Optional[torch.FloatTensor] = None,
159
+ w2c: Optional[torch.FloatTensor] = None,
160
+ proj_mtx: Optional[torch.FloatTensor] = None,
161
+ aspect_wh: float = 1.0,
162
+ near: float = 0.1,
163
+ far: float = 100.0,
164
+ device: Optional[str] = None,
165
+ ):
166
+ if w2c is None:
167
+ if c2w is None:
168
+ c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
169
+ camera_positions = c2w[:, :3, 3]
170
+ w2c = torch.linalg.inv(c2w)
171
+ else:
172
+ camera_positions = None
173
+ c2w = None
174
+ if proj_mtx is None:
175
+ proj_mtx = get_projection_matrix(
176
+ fovy_deg, aspect_wh=aspect_wh, near=near, far=far, device=device
177
+ )
178
+ mvp_mtx = proj_mtx @ w2c
179
+ return Camera(
180
+ c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions
181
+ )
182
+
183
+
184
+ def get_orthogonal_camera(
185
+ elevation_deg: LIST_TYPE,
186
+ distance: LIST_TYPE,
187
+ left: float,
188
+ right: float,
189
+ bottom: float,
190
+ top: float,
191
+ azimuth_deg: Optional[LIST_TYPE] = None,
192
+ num_views: Optional[int] = 1,
193
+ near: float = 0.1,
194
+ far: float = 100.0,
195
+ device: Optional[str] = None,
196
+ ):
197
+ c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device)
198
+ camera_positions = c2w[:, :3, 3]
199
+ w2c = torch.linalg.inv(c2w)
200
+ proj_mtx = get_orthogonal_projection_matrix(
201
+ batch_size=c2w.shape[0],
202
+ left=left,
203
+ right=right,
204
+ bottom=bottom,
205
+ top=top,
206
+ near=near,
207
+ far=far,
208
+ device=device,
209
+ )
210
+ mvp_mtx = proj_mtx @ w2c
211
+ return Camera(
212
+ c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions
213
+ )
mvadapter/utils/geometry.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def get_position_map_from_depth(depth, mask, intrinsics, extrinsics, image_wh=None):
9
+ """Compute the position map from the depth map and the camera parameters for a batch of views.
10
+
11
+ Args:
12
+ depth (torch.Tensor): The depth maps with the shape (B, H, W, 1).
13
+ mask (torch.Tensor): The masks with the shape (B, H, W, 1).
14
+ intrinsics (torch.Tensor): The camera intrinsics matrices with the shape (B, 3, 3).
15
+ extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4).
16
+ image_wh (Tuple[int, int]): The image width and height.
17
+
18
+ Returns:
19
+ torch.Tensor: The position maps with the shape (B, H, W, 3).
20
+ """
21
+ if image_wh is None:
22
+ image_wh = depth.shape[2], depth.shape[1]
23
+
24
+ B, H, W, _ = depth.shape
25
+ depth = depth.squeeze(-1)
26
+
27
+ u_coord, v_coord = torch.meshgrid(
28
+ torch.arange(image_wh[0]), torch.arange(image_wh[1]), indexing="xy"
29
+ )
30
+ u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
31
+ v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
32
+
33
+ # Compute the position map by back-projecting depth pixels to 3D space
34
+ x = (
35
+ (u_coord - intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1))
36
+ * depth
37
+ / intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1)
38
+ )
39
+ y = (
40
+ (v_coord - intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1))
41
+ * depth
42
+ / intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1)
43
+ )
44
+ z = depth
45
+
46
+ # Concatenate to form the 3D coordinates in the camera frame
47
+ camera_coords = torch.stack([x, y, z], dim=-1)
48
+
49
+ # Apply the extrinsic matrix to get coordinates in the world frame
50
+ coords_homogeneous = torch.nn.functional.pad(
51
+ camera_coords, (0, 1), "constant", 1.0
52
+ ) # Add a homogeneous coordinate
53
+ world_coords = torch.matmul(
54
+ coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2)
55
+ ).view(B, H, W, 4)
56
+
57
+ # Apply the mask to the position map
58
+ position_map = world_coords[..., :3] * mask
59
+
60
+ return position_map
61
+
62
+
63
+ def get_position_map_from_depth_ortho(
64
+ depth, mask, extrinsics, ortho_scale, image_wh=None
65
+ ):
66
+ """Compute the position map from the depth map and the camera parameters for a batch of views
67
+ using orthographic projection with a given ortho_scale.
68
+
69
+ Args:
70
+ depth (torch.Tensor): The depth maps with the shape (B, H, W, 1).
71
+ mask (torch.Tensor): The masks with the shape (B, H, W, 1).
72
+ extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4).
73
+ ortho_scale (torch.Tensor): The scaling factor for the orthographic projection with the shape (B, 1, 1, 1).
74
+ image_wh (Tuple[int, int]): Optional. The image width and height.
75
+
76
+ Returns:
77
+ torch.Tensor: The position maps with the shape (B, H, W, 3).
78
+ """
79
+ if image_wh is None:
80
+ image_wh = depth.shape[2], depth.shape[1]
81
+
82
+ B, H, W, _ = depth.shape
83
+ depth = depth.squeeze(-1)
84
+
85
+ # Generating grid of coordinates in the image space
86
+ u_coord, v_coord = torch.meshgrid(
87
+ torch.arange(0, image_wh[0]), torch.arange(0, image_wh[1]), indexing="xy"
88
+ )
89
+ u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
90
+ v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1)
91
+
92
+ # Compute the position map using orthographic projection with ortho_scale
93
+ x = (u_coord - image_wh[0] / 2) / ortho_scale / image_wh[0]
94
+ y = (v_coord - image_wh[1] / 2) / ortho_scale / image_wh[1]
95
+ z = depth
96
+
97
+ # Concatenate to form the 3D coordinates in the camera frame
98
+ camera_coords = torch.stack([x, y, z], dim=-1)
99
+
100
+ # Apply the extrinsic matrix to get coordinates in the world frame
101
+ coords_homogeneous = torch.nn.functional.pad(
102
+ camera_coords, (0, 1), "constant", 1.0
103
+ ) # Add a homogeneous coordinate
104
+ world_coords = torch.matmul(
105
+ coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2)
106
+ ).view(B, H, W, 4)
107
+
108
+ # Apply the mask to the position map
109
+ position_map = world_coords[..., :3] * mask
110
+
111
+ return position_map
112
+
113
+
114
+ def get_opencv_from_blender(matrix_world, fov=None, image_size=None):
115
+ # convert matrix_world to opencv format extrinsics
116
+ opencv_world_to_cam = matrix_world.inverse()
117
+ opencv_world_to_cam[1, :] *= -1
118
+ opencv_world_to_cam[2, :] *= -1
119
+ R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3]
120
+
121
+ if fov is None: # orthographic camera
122
+ return R, T
123
+
124
+ R, T = R.unsqueeze(0), T.unsqueeze(0)
125
+ # convert fov to opencv format intrinsics
126
+ focal = 1 / np.tan(fov / 2)
127
+ intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32)
128
+ opencv_cam_matrix = (
129
+ torch.from_numpy(intrinsics).unsqueeze(0).float().to(matrix_world.device)
130
+ )
131
+ opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2]).to(
132
+ matrix_world.device
133
+ )
134
+ opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2
135
+
136
+ return R, T, opencv_cam_matrix
137
+
138
+
139
+ def get_ray_directions(
140
+ H: int,
141
+ W: int,
142
+ focal: float,
143
+ principal: Optional[Tuple[float, float]] = None,
144
+ use_pixel_centers: bool = True,
145
+ ) -> torch.Tensor:
146
+ """
147
+ Get ray directions for all pixels in camera coordinate.
148
+ Args:
149
+ H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
150
+ Outputs:
151
+ directions: (H, W, 3), the direction of the rays in camera coordinate
152
+ """
153
+ pixel_center = 0.5 if use_pixel_centers else 0
154
+ cx, cy = W / 2, H / 2 if principal is None else principal
155
+ i, j = torch.meshgrid(
156
+ torch.arange(W, dtype=torch.float32) + pixel_center,
157
+ torch.arange(H, dtype=torch.float32) + pixel_center,
158
+ indexing="xy",
159
+ )
160
+ directions = torch.stack(
161
+ [(i - cx) / focal, -(j - cy) / focal, -torch.ones_like(i)], -1
162
+ )
163
+ return F.normalize(directions, dim=-1)
164
+
165
+
166
+ def get_rays(
167
+ directions: torch.Tensor, c2w: torch.Tensor
168
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
169
+ """
170
+ Get ray origins and directions from camera coordinates to world coordinates
171
+ Args:
172
+ directions: (H, W, 3) ray directions in camera coordinates
173
+ c2w: (4, 4) camera-to-world transformation matrix
174
+ Outputs:
175
+ rays_o, rays_d: (H, W, 3) ray origins and directions in world coordinates
176
+ """
177
+ # Rotate ray directions from camera coordinate to the world coordinate
178
+ rays_d = directions @ c2w[:3, :3].T
179
+ rays_o = c2w[:3, 3].expand(rays_d.shape)
180
+ return rays_o, rays_d
181
+
182
+
183
+ def compute_plucker_embed(
184
+ c2w: torch.Tensor, image_width: int, image_height: int, focal: float
185
+ ) -> torch.Tensor:
186
+ """
187
+ Computes Plucker coordinates for a camera.
188
+ Args:
189
+ c2w: (4, 4) camera-to-world transformation matrix
190
+ image_width: Image width
191
+ image_height: Image height
192
+ focal: Focal length of the camera
193
+ Returns:
194
+ plucker: (6, H, W) Plucker embedding
195
+ """
196
+ directions = get_ray_directions(image_height, image_width, focal)
197
+ rays_o, rays_d = get_rays(directions, c2w)
198
+ # Cross product to get Plucker coordinates
199
+ cross = torch.cross(rays_o, rays_d, dim=-1)
200
+ plucker = torch.cat((rays_d, cross), dim=-1)
201
+ return plucker.permute(2, 0, 1)
202
+
203
+
204
+ def get_plucker_embeds_from_cameras(
205
+ c2w: List[torch.Tensor], fov: List[float], image_size: int
206
+ ) -> torch.Tensor:
207
+ """
208
+ Given lists of camera transformations and fov, returns the batched plucker embeddings.
209
+ Args:
210
+ c2w: list of camera-to-world transformation matrices
211
+ fov: list of field of view values
212
+ image_size: size of the image
213
+ Returns:
214
+ plucker_embeds: (B, 6, H, W) batched plucker embeddings
215
+ """
216
+ plucker_embeds = []
217
+ for cam_matrix, cam_fov in zip(c2w, fov):
218
+ focal = 0.5 * image_size / np.tan(0.5 * cam_fov)
219
+ plucker = compute_plucker_embed(cam_matrix, image_size, image_size, focal)
220
+ plucker_embeds.append(plucker)
221
+ return torch.stack(plucker_embeds)
222
+
223
+
224
+ def get_plucker_embeds_from_cameras_ortho(
225
+ c2w: List[torch.Tensor], ortho_scale: List[float], image_size: int
226
+ ):
227
+ """
228
+ Given lists of camera transformations and fov, returns the batched plucker embeddings.
229
+
230
+ Parameters:
231
+ c2w: list of camera-to-world transformation matrices
232
+ fov: list of field of view values
233
+ image_size: size of the image
234
+
235
+ Returns:
236
+ plucker_embeds: plucker embeddings (B, 6, H, W)
237
+ """
238
+ plucker_embeds = []
239
+ # compute pairwise mask and plucker embeddings
240
+ for cam_matrix, scale in zip(c2w, ortho_scale):
241
+ # blender to opencv to pytorch3d
242
+ R, T = get_opencv_from_blender(cam_matrix)
243
+ cam_pos = -R.T @ T
244
+ view_dir = R.T @ torch.tensor([0, 0, 1]).float().to(cam_matrix.device)
245
+ # normalize camera position
246
+ cam_pos = F.normalize(cam_pos, dim=0)
247
+ plucker = torch.concat([view_dir, cam_pos])
248
+ plucker = plucker.unsqueeze(-1).unsqueeze(-1).repeat(1, image_size, image_size)
249
+ plucker_embeds.append(plucker)
250
+
251
+ plucker_embeds = torch.stack(plucker_embeds)
252
+
253
+ return plucker_embeds
mvadapter/utils/saving.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+
8
+
9
+ def tensor_to_image(
10
+ data: Union[Image.Image, torch.Tensor, np.ndarray],
11
+ batched: bool = False,
12
+ format: str = "HWC",
13
+ ) -> Union[Image.Image, List[Image.Image]]:
14
+ if isinstance(data, Image.Image):
15
+ return data
16
+ if isinstance(data, torch.Tensor):
17
+ data = data.detach().cpu().numpy()
18
+ if data.dtype == np.float32 or data.dtype == np.float16:
19
+ data = (data * 255).astype(np.uint8)
20
+ elif data.dtype == np.bool_:
21
+ data = data.astype(np.uint8) * 255
22
+ assert data.dtype == np.uint8
23
+ if format == "CHW":
24
+ if batched and data.ndim == 4:
25
+ data = data.transpose((0, 2, 3, 1))
26
+ elif not batched and data.ndim == 3:
27
+ data = data.transpose((1, 2, 0))
28
+
29
+ if batched:
30
+ return [Image.fromarray(d) for d in data]
31
+ return Image.fromarray(data)
32
+
33
+
34
+ def largest_factor_near_sqrt(n: int) -> int:
35
+ """
36
+ Finds the largest factor of n that is closest to the square root of n.
37
+
38
+ Args:
39
+ n (int): The integer for which to find the largest factor near its square root.
40
+
41
+ Returns:
42
+ int: The largest factor of n that is closest to the square root of n.
43
+ """
44
+ sqrt_n = int(math.sqrt(n)) # Get the integer part of the square root
45
+
46
+ # First, check if the square root itself is a factor
47
+ if sqrt_n * sqrt_n == n:
48
+ return sqrt_n
49
+
50
+ # Otherwise, find the largest factor by iterating from sqrt_n downwards
51
+ for i in range(sqrt_n, 0, -1):
52
+ if n % i == 0:
53
+ return i
54
+
55
+ # If n is 1, return 1
56
+ return 1
57
+
58
+
59
+ def make_image_grid(
60
+ images: List[Image.Image],
61
+ rows: Optional[int] = None,
62
+ cols: Optional[int] = None,
63
+ resize: Optional[int] = None,
64
+ ) -> Image.Image:
65
+ """
66
+ Prepares a single grid of images. Useful for visualization purposes.
67
+ """
68
+ if rows is None and cols is not None:
69
+ assert len(images) % cols == 0
70
+ rows = len(images) // cols
71
+ elif cols is None and rows is not None:
72
+ assert len(images) % rows == 0
73
+ cols = len(images) // rows
74
+ elif rows is None and cols is None:
75
+ rows = largest_factor_near_sqrt(len(images))
76
+ cols = len(images) // rows
77
+
78
+ assert len(images) == rows * cols
79
+
80
+ if resize is not None:
81
+ images = [img.resize((resize, resize)) for img in images]
82
+
83
+ w, h = images[0].size
84
+ grid = Image.new("RGB", size=(cols * w, rows * h))
85
+
86
+ for i, img in enumerate(images):
87
+ grid.paste(img, box=(i % cols * w, i // cols * h))
88
+ return grid