multimodalart HF staff commited on
Commit
69c26b8
1 Parent(s): 974def6

Upload folder using huggingface_hub

Browse files
controlnet_flux.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.loaders import PeftAdapterMixin
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from diffusers.models.attention_processor import AttentionProcessor
11
+ from diffusers.utils import (
12
+ USE_PEFT_BACKEND,
13
+ is_torch_version,
14
+ logging,
15
+ scale_lora_layers,
16
+ unscale_lora_layers,
17
+ )
18
+ from diffusers.models.controlnet import BaseOutput, zero_module
19
+ from diffusers.models.embeddings import (
20
+ CombinedTimestepGuidanceTextProjEmbeddings,
21
+ CombinedTimestepTextProjEmbeddings,
22
+ )
23
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
24
+ from transformer_flux import (
25
+ EmbedND,
26
+ FluxSingleTransformerBlock,
27
+ FluxTransformerBlock,
28
+ )
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ class FluxControlNetOutput(BaseOutput):
36
+ controlnet_block_samples: Tuple[torch.Tensor]
37
+ controlnet_single_block_samples: Tuple[torch.Tensor]
38
+
39
+
40
+ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
41
+ _supports_gradient_checkpointing = True
42
+
43
+ @register_to_config
44
+ def __init__(
45
+ self,
46
+ patch_size: int = 1,
47
+ in_channels: int = 64,
48
+ num_layers: int = 19,
49
+ num_single_layers: int = 38,
50
+ attention_head_dim: int = 128,
51
+ num_attention_heads: int = 24,
52
+ joint_attention_dim: int = 4096,
53
+ pooled_projection_dim: int = 768,
54
+ guidance_embeds: bool = False,
55
+ axes_dims_rope: List[int] = [16, 56, 56],
56
+ extra_condition_channels: int = 1 * 4,
57
+ ):
58
+ super().__init__()
59
+ self.out_channels = in_channels
60
+ self.inner_dim = num_attention_heads * attention_head_dim
61
+
62
+ self.pos_embed = EmbedND(
63
+ dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
64
+ )
65
+ text_time_guidance_cls = (
66
+ CombinedTimestepGuidanceTextProjEmbeddings
67
+ if guidance_embeds
68
+ else CombinedTimestepTextProjEmbeddings
69
+ )
70
+ self.time_text_embed = text_time_guidance_cls(
71
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
72
+ )
73
+
74
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
75
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
76
+
77
+ self.transformer_blocks = nn.ModuleList(
78
+ [
79
+ FluxTransformerBlock(
80
+ dim=self.inner_dim,
81
+ num_attention_heads=num_attention_heads,
82
+ attention_head_dim=attention_head_dim,
83
+ )
84
+ for _ in range(num_layers)
85
+ ]
86
+ )
87
+
88
+ self.single_transformer_blocks = nn.ModuleList(
89
+ [
90
+ FluxSingleTransformerBlock(
91
+ dim=self.inner_dim,
92
+ num_attention_heads=num_attention_heads,
93
+ attention_head_dim=attention_head_dim,
94
+ )
95
+ for _ in range(num_single_layers)
96
+ ]
97
+ )
98
+
99
+ # controlnet_blocks
100
+ self.controlnet_blocks = nn.ModuleList([])
101
+ for _ in range(len(self.transformer_blocks)):
102
+ self.controlnet_blocks.append(
103
+ zero_module(nn.Linear(self.inner_dim, self.inner_dim))
104
+ )
105
+
106
+ self.controlnet_single_blocks = nn.ModuleList([])
107
+ for _ in range(len(self.single_transformer_blocks)):
108
+ self.controlnet_single_blocks.append(
109
+ zero_module(nn.Linear(self.inner_dim, self.inner_dim))
110
+ )
111
+
112
+ self.controlnet_x_embedder = zero_module(
113
+ torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim)
114
+ )
115
+
116
+ self.gradient_checkpointing = False
117
+
118
+ @property
119
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
120
+ def attn_processors(self):
121
+ r"""
122
+ Returns:
123
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
124
+ indexed by its weight name.
125
+ """
126
+ # set recursively
127
+ processors = {}
128
+
129
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
130
+ if hasattr(module, "get_processor"):
131
+ processors[f"{name}.processor"] = module.get_processor()
132
+
133
+ for sub_name, child in module.named_children():
134
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
135
+
136
+ return processors
137
+
138
+ for name, module in self.named_children():
139
+ fn_recursive_add_processors(name, module, processors)
140
+
141
+ return processors
142
+
143
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
144
+ def set_attn_processor(self, processor):
145
+ r"""
146
+ Sets the attention processor to use to compute attention.
147
+
148
+ Parameters:
149
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
150
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
151
+ for **all** `Attention` layers.
152
+
153
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
154
+ processor. This is strongly recommended when setting trainable attention processors.
155
+
156
+ """
157
+ count = len(self.attn_processors.keys())
158
+
159
+ if isinstance(processor, dict) and len(processor) != count:
160
+ raise ValueError(
161
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
162
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
163
+ )
164
+
165
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
166
+ if hasattr(module, "set_processor"):
167
+ if not isinstance(processor, dict):
168
+ module.set_processor(processor)
169
+ else:
170
+ module.set_processor(processor.pop(f"{name}.processor"))
171
+
172
+ for sub_name, child in module.named_children():
173
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
174
+
175
+ for name, module in self.named_children():
176
+ fn_recursive_attn_processor(name, module, processor)
177
+
178
+ def _set_gradient_checkpointing(self, module, value=False):
179
+ if hasattr(module, "gradient_checkpointing"):
180
+ module.gradient_checkpointing = value
181
+
182
+ @classmethod
183
+ def from_transformer(
184
+ cls,
185
+ transformer,
186
+ num_layers: int = 4,
187
+ num_single_layers: int = 10,
188
+ attention_head_dim: int = 128,
189
+ num_attention_heads: int = 24,
190
+ load_weights_from_transformer=True,
191
+ ):
192
+ config = transformer.config
193
+ config["num_layers"] = num_layers
194
+ config["num_single_layers"] = num_single_layers
195
+ config["attention_head_dim"] = attention_head_dim
196
+ config["num_attention_heads"] = num_attention_heads
197
+
198
+ controlnet = cls(**config)
199
+
200
+ if load_weights_from_transformer:
201
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
202
+ controlnet.time_text_embed.load_state_dict(
203
+ transformer.time_text_embed.state_dict()
204
+ )
205
+ controlnet.context_embedder.load_state_dict(
206
+ transformer.context_embedder.state_dict()
207
+ )
208
+ controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
209
+ controlnet.transformer_blocks.load_state_dict(
210
+ transformer.transformer_blocks.state_dict(), strict=False
211
+ )
212
+ controlnet.single_transformer_blocks.load_state_dict(
213
+ transformer.single_transformer_blocks.state_dict(), strict=False
214
+ )
215
+
216
+ controlnet.controlnet_x_embedder = zero_module(
217
+ controlnet.controlnet_x_embedder
218
+ )
219
+
220
+ return controlnet
221
+
222
+ def forward(
223
+ self,
224
+ hidden_states: torch.Tensor,
225
+ controlnet_cond: torch.Tensor,
226
+ conditioning_scale: float = 1.0,
227
+ encoder_hidden_states: torch.Tensor = None,
228
+ pooled_projections: torch.Tensor = None,
229
+ timestep: torch.LongTensor = None,
230
+ img_ids: torch.Tensor = None,
231
+ txt_ids: torch.Tensor = None,
232
+ guidance: torch.Tensor = None,
233
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
234
+ return_dict: bool = True,
235
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
236
+ """
237
+ The [`FluxTransformer2DModel`] forward method.
238
+
239
+ Args:
240
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
241
+ Input `hidden_states`.
242
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
243
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
244
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
245
+ from the embeddings of input conditions.
246
+ timestep ( `torch.LongTensor`):
247
+ Used to indicate denoising step.
248
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
249
+ A list of tensors that if specified are added to the residuals of transformer blocks.
250
+ joint_attention_kwargs (`dict`, *optional*):
251
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
252
+ `self.processor` in
253
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
254
+ return_dict (`bool`, *optional*, defaults to `True`):
255
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
256
+ tuple.
257
+
258
+ Returns:
259
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
260
+ `tuple` where the first element is the sample tensor.
261
+ """
262
+ if joint_attention_kwargs is not None:
263
+ joint_attention_kwargs = joint_attention_kwargs.copy()
264
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
265
+ else:
266
+ lora_scale = 1.0
267
+
268
+ if USE_PEFT_BACKEND:
269
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
270
+ scale_lora_layers(self, lora_scale)
271
+ else:
272
+ if (
273
+ joint_attention_kwargs is not None
274
+ and joint_attention_kwargs.get("scale", None) is not None
275
+ ):
276
+ logger.warning(
277
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
278
+ )
279
+ hidden_states = self.x_embedder(hidden_states)
280
+
281
+ # add condition
282
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
283
+
284
+ timestep = timestep.to(hidden_states.dtype) * 1000
285
+ if guidance is not None:
286
+ guidance = guidance.to(hidden_states.dtype) * 1000
287
+ else:
288
+ guidance = None
289
+ temb = (
290
+ self.time_text_embed(timestep, pooled_projections)
291
+ if guidance is None
292
+ else self.time_text_embed(timestep, guidance, pooled_projections)
293
+ )
294
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
295
+
296
+ txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
297
+ ids = torch.cat((txt_ids, img_ids), dim=1)
298
+ image_rotary_emb = self.pos_embed(ids)
299
+
300
+ block_samples = ()
301
+ for _, block in enumerate(self.transformer_blocks):
302
+ if self.training and self.gradient_checkpointing:
303
+
304
+ def create_custom_forward(module, return_dict=None):
305
+ def custom_forward(*inputs):
306
+ if return_dict is not None:
307
+ return module(*inputs, return_dict=return_dict)
308
+ else:
309
+ return module(*inputs)
310
+
311
+ return custom_forward
312
+
313
+ ckpt_kwargs: Dict[str, Any] = (
314
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
315
+ )
316
+ (
317
+ encoder_hidden_states,
318
+ hidden_states,
319
+ ) = torch.utils.checkpoint.checkpoint(
320
+ create_custom_forward(block),
321
+ hidden_states,
322
+ encoder_hidden_states,
323
+ temb,
324
+ image_rotary_emb,
325
+ **ckpt_kwargs,
326
+ )
327
+
328
+ else:
329
+ encoder_hidden_states, hidden_states = block(
330
+ hidden_states=hidden_states,
331
+ encoder_hidden_states=encoder_hidden_states,
332
+ temb=temb,
333
+ image_rotary_emb=image_rotary_emb,
334
+ )
335
+ block_samples = block_samples + (hidden_states,)
336
+
337
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
338
+
339
+ single_block_samples = ()
340
+ for _, block in enumerate(self.single_transformer_blocks):
341
+ if self.training and self.gradient_checkpointing:
342
+
343
+ def create_custom_forward(module, return_dict=None):
344
+ def custom_forward(*inputs):
345
+ if return_dict is not None:
346
+ return module(*inputs, return_dict=return_dict)
347
+ else:
348
+ return module(*inputs)
349
+
350
+ return custom_forward
351
+
352
+ ckpt_kwargs: Dict[str, Any] = (
353
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
354
+ )
355
+ hidden_states = torch.utils.checkpoint.checkpoint(
356
+ create_custom_forward(block),
357
+ hidden_states,
358
+ temb,
359
+ image_rotary_emb,
360
+ **ckpt_kwargs,
361
+ )
362
+
363
+ else:
364
+ hidden_states = block(
365
+ hidden_states=hidden_states,
366
+ temb=temb,
367
+ image_rotary_emb=image_rotary_emb,
368
+ )
369
+ single_block_samples = single_block_samples + (
370
+ hidden_states[:, encoder_hidden_states.shape[1] :],
371
+ )
372
+
373
+ # controlnet block
374
+ controlnet_block_samples = ()
375
+ for block_sample, controlnet_block in zip(
376
+ block_samples, self.controlnet_blocks
377
+ ):
378
+ block_sample = controlnet_block(block_sample)
379
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
380
+
381
+ controlnet_single_block_samples = ()
382
+ for single_block_sample, controlnet_block in zip(
383
+ single_block_samples, self.controlnet_single_blocks
384
+ ):
385
+ single_block_sample = controlnet_block(single_block_sample)
386
+ controlnet_single_block_samples = controlnet_single_block_samples + (
387
+ single_block_sample,
388
+ )
389
+
390
+ # scaling
391
+ controlnet_block_samples = [
392
+ sample * conditioning_scale for sample in controlnet_block_samples
393
+ ]
394
+ controlnet_single_block_samples = [
395
+ sample * conditioning_scale for sample in controlnet_single_block_samples
396
+ ]
397
+
398
+ #
399
+ controlnet_block_samples = (
400
+ None if len(controlnet_block_samples) == 0 else controlnet_block_samples
401
+ )
402
+ controlnet_single_block_samples = (
403
+ None
404
+ if len(controlnet_single_block_samples) == 0
405
+ else controlnet_single_block_samples
406
+ )
407
+
408
+ if USE_PEFT_BACKEND:
409
+ # remove `lora_scale` from each PEFT layer
410
+ unscale_lora_layers(self, lora_scale)
411
+
412
+ if not return_dict:
413
+ return (controlnet_block_samples, controlnet_single_block_samples)
414
+
415
+ return FluxControlNetOutput(
416
+ controlnet_block_samples=controlnet_block_samples,
417
+ controlnet_single_block_samples=controlnet_single_block_samples,
418
+ )
images/0.jpg ADDED
images/1.jpg ADDED
images/2.jpg ADDED
images/3.jpg ADDED
images/alibaba.png ADDED
images/alibabaalimama.png ADDED
images/alimama.png ADDED
images/flux1.jpg ADDED
images/flux2.jpg ADDED
images/flux3.jpg ADDED
main.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers.utils import load_image, check_min_version
3
+ from controlnet_flux import FluxControlNetModel
4
+ from transformer_flux import FluxTransformer2DModel
5
+ from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
6
+
7
+ check_min_version("0.30.2")
8
+
9
+ # Set image path , mask path and prompt
10
+ image_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket.png',
11
+ mask_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket_mask.jpeg',
12
+ prompt='a person wearing a white shoe, carrying a white bucket with text "FLUX" on it'
13
+
14
+ # Build pipeline
15
+ controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16)
16
+ transformer = FluxTransformer2DModel.from_pretrained(
17
+ "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dytpe=torch.bfloat16
18
+ )
19
+ pipe = FluxControlNetInpaintingPipeline.from_pretrained(
20
+ "black-forest-labs/FLUX.1-dev",
21
+ controlnet=controlnet,
22
+ transformer=transformer,
23
+ torch_dtype=torch.bfloat16
24
+ ).to("cuda")
25
+ pipe.transformer.to(torch.bfloat16)
26
+ pipe.controlnet.to(torch.bfloat16)
27
+
28
+ # Load image and mask
29
+ size = (768, 768)
30
+ image = load_image(image_path).convert("RGB").resize(size)
31
+ mask = load_image(mask_path).convert("RGB").resize(size)
32
+ generator = torch.Generator(device="cuda").manual_seed(24)
33
+
34
+ # Inpaint
35
+ result = pipe(
36
+ prompt=prompt,
37
+ height=size[1],
38
+ width=size[0],
39
+ control_image=image,
40
+ control_mask=mask,
41
+ num_inference_steps=28,
42
+ generator=generator,
43
+ controlnet_conditioning_scale=0.9,
44
+ guidance_scale=3.5,
45
+ negative_prompt="",
46
+ true_guidance_scale=3.5
47
+ ).images[0]
48
+
49
+ result.save('flux_inpaint.png')
50
+ print("Successfully inpaint image")
pipeline_flux_controlnet_inpaint.py ADDED
@@ -0,0 +1,1049 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import (
7
+ CLIPTextModel,
8
+ CLIPTokenizer,
9
+ T5EncoderModel,
10
+ T5TokenizerFast,
11
+ )
12
+
13
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
14
+ from diffusers.loaders import FluxLoraLoaderMixin
15
+ from diffusers.models.autoencoders import AutoencoderKL
16
+
17
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
18
+ from diffusers.utils import (
19
+ USE_PEFT_BACKEND,
20
+ is_torch_xla_available,
21
+ logging,
22
+ replace_example_docstring,
23
+ scale_lora_layers,
24
+ unscale_lora_layers,
25
+ )
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
29
+
30
+ from transformer_flux import FluxTransformer2DModel
31
+ from controlnet_flux import FluxControlNetModel
32
+
33
+ if is_torch_xla_available():
34
+ import torch_xla.core.xla_model as xm
35
+
36
+ XLA_AVAILABLE = True
37
+ else:
38
+ XLA_AVAILABLE = False
39
+
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+ EXAMPLE_DOC_STRING = """
44
+ Examples:
45
+ ```py
46
+ >>> import torch
47
+ >>> from diffusers.utils import load_image
48
+ >>> from diffusers import FluxControlNetPipeline
49
+ >>> from diffusers import FluxControlNetModel
50
+
51
+ >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha"
52
+ >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
53
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
54
+ ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
55
+ ... )
56
+ >>> pipe.to("cuda")
57
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
58
+ >>> control_mask = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
59
+ >>> prompt = "A girl in city, 25 years old, cool, futuristic"
60
+ >>> image = pipe(
61
+ ... prompt,
62
+ ... control_image=control_image,
63
+ ... controlnet_conditioning_scale=0.6,
64
+ ... num_inference_steps=28,
65
+ ... guidance_scale=3.5,
66
+ ... ).images[0]
67
+ >>> image.save("flux.png")
68
+ ```
69
+ """
70
+
71
+
72
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
73
+ def calculate_shift(
74
+ image_seq_len,
75
+ base_seq_len: int = 256,
76
+ max_seq_len: int = 4096,
77
+ base_shift: float = 0.5,
78
+ max_shift: float = 1.16,
79
+ ):
80
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
81
+ b = base_shift - m * base_seq_len
82
+ mu = image_seq_len * m + b
83
+ return mu
84
+
85
+
86
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
87
+ def retrieve_timesteps(
88
+ scheduler,
89
+ num_inference_steps: Optional[int] = None,
90
+ device: Optional[Union[str, torch.device]] = None,
91
+ timesteps: Optional[List[int]] = None,
92
+ sigmas: Optional[List[float]] = None,
93
+ **kwargs,
94
+ ):
95
+ """
96
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
97
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
98
+
99
+ Args:
100
+ scheduler (`SchedulerMixin`):
101
+ The scheduler to get timesteps from.
102
+ num_inference_steps (`int`):
103
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
104
+ must be `None`.
105
+ device (`str` or `torch.device`, *optional*):
106
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
107
+ timesteps (`List[int]`, *optional*):
108
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
109
+ `num_inference_steps` and `sigmas` must be `None`.
110
+ sigmas (`List[float]`, *optional*):
111
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
112
+ `num_inference_steps` and `timesteps` must be `None`.
113
+
114
+ Returns:
115
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
116
+ second element is the number of inference steps.
117
+ """
118
+ if timesteps is not None and sigmas is not None:
119
+ raise ValueError(
120
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
121
+ )
122
+ if timesteps is not None:
123
+ accepts_timesteps = "timesteps" in set(
124
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
125
+ )
126
+ if not accepts_timesteps:
127
+ raise ValueError(
128
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
129
+ f" timestep schedules. Please check whether you are using the correct scheduler."
130
+ )
131
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
132
+ timesteps = scheduler.timesteps
133
+ num_inference_steps = len(timesteps)
134
+ elif sigmas is not None:
135
+ accept_sigmas = "sigmas" in set(
136
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
137
+ )
138
+ if not accept_sigmas:
139
+ raise ValueError(
140
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
141
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
142
+ )
143
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
144
+ timesteps = scheduler.timesteps
145
+ num_inference_steps = len(timesteps)
146
+ else:
147
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
148
+ timesteps = scheduler.timesteps
149
+ return timesteps, num_inference_steps
150
+
151
+
152
+ class FluxControlNetInpaintingPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
153
+ r"""
154
+ The Flux pipeline for text-to-image generation.
155
+
156
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
157
+
158
+ Args:
159
+ transformer ([`FluxTransformer2DModel`]):
160
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
161
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
162
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
163
+ vae ([`AutoencoderKL`]):
164
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
165
+ text_encoder ([`CLIPTextModel`]):
166
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
167
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
168
+ text_encoder_2 ([`T5EncoderModel`]):
169
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
170
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
171
+ tokenizer (`CLIPTokenizer`):
172
+ Tokenizer of class
173
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
174
+ tokenizer_2 (`T5TokenizerFast`):
175
+ Second Tokenizer of class
176
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
177
+ """
178
+
179
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
180
+ _optional_components = []
181
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
182
+
183
+ def __init__(
184
+ self,
185
+ scheduler: FlowMatchEulerDiscreteScheduler,
186
+ vae: AutoencoderKL,
187
+ text_encoder: CLIPTextModel,
188
+ tokenizer: CLIPTokenizer,
189
+ text_encoder_2: T5EncoderModel,
190
+ tokenizer_2: T5TokenizerFast,
191
+ transformer: FluxTransformer2DModel,
192
+ controlnet: FluxControlNetModel,
193
+ ):
194
+ super().__init__()
195
+
196
+ self.register_modules(
197
+ vae=vae,
198
+ text_encoder=text_encoder,
199
+ text_encoder_2=text_encoder_2,
200
+ tokenizer=tokenizer,
201
+ tokenizer_2=tokenizer_2,
202
+ transformer=transformer,
203
+ scheduler=scheduler,
204
+ controlnet=controlnet,
205
+ )
206
+ self.vae_scale_factor = (
207
+ 2 ** (len(self.vae.config.block_out_channels))
208
+ if hasattr(self, "vae") and self.vae is not None
209
+ else 16
210
+ )
211
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=True, do_convert_rgb=True, do_normalize=True)
212
+ self.mask_processor = VaeImageProcessor(
213
+ vae_scale_factor=self.vae_scale_factor,
214
+ do_resize=True,
215
+ do_convert_grayscale=True,
216
+ do_normalize=False,
217
+ do_binarize=True,
218
+ )
219
+ self.tokenizer_max_length = (
220
+ self.tokenizer.model_max_length
221
+ if hasattr(self, "tokenizer") and self.tokenizer is not None
222
+ else 77
223
+ )
224
+ self.default_sample_size = 64
225
+
226
+ @property
227
+ def do_classifier_free_guidance(self):
228
+ return self._guidance_scale > 1
229
+
230
+ def _get_t5_prompt_embeds(
231
+ self,
232
+ prompt: Union[str, List[str]] = None,
233
+ num_images_per_prompt: int = 1,
234
+ max_sequence_length: int = 512,
235
+ device: Optional[torch.device] = None,
236
+ dtype: Optional[torch.dtype] = None,
237
+ ):
238
+ device = device or self._execution_device
239
+ dtype = dtype or self.text_encoder.dtype
240
+
241
+ prompt = [prompt] if isinstance(prompt, str) else prompt
242
+ batch_size = len(prompt)
243
+
244
+ text_inputs = self.tokenizer_2(
245
+ prompt,
246
+ padding="max_length",
247
+ max_length=max_sequence_length,
248
+ truncation=True,
249
+ return_length=False,
250
+ return_overflowing_tokens=False,
251
+ return_tensors="pt",
252
+ )
253
+ text_input_ids = text_inputs.input_ids
254
+ untruncated_ids = self.tokenizer_2(
255
+ prompt, padding="longest", return_tensors="pt"
256
+ ).input_ids
257
+
258
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
259
+ text_input_ids, untruncated_ids
260
+ ):
261
+ removed_text = self.tokenizer_2.batch_decode(
262
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
263
+ )
264
+ logger.warning(
265
+ "The following part of your input was truncated because `max_sequence_length` is set to "
266
+ f" {max_sequence_length} tokens: {removed_text}"
267
+ )
268
+
269
+ prompt_embeds = self.text_encoder_2(
270
+ text_input_ids.to(device), output_hidden_states=False
271
+ )[0]
272
+
273
+ dtype = self.text_encoder_2.dtype
274
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
275
+
276
+ _, seq_len, _ = prompt_embeds.shape
277
+
278
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
279
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
280
+ prompt_embeds = prompt_embeds.view(
281
+ batch_size * num_images_per_prompt, seq_len, -1
282
+ )
283
+
284
+ return prompt_embeds
285
+
286
+ def _get_clip_prompt_embeds(
287
+ self,
288
+ prompt: Union[str, List[str]],
289
+ num_images_per_prompt: int = 1,
290
+ device: Optional[torch.device] = None,
291
+ ):
292
+ device = device or self._execution_device
293
+
294
+ prompt = [prompt] if isinstance(prompt, str) else prompt
295
+ batch_size = len(prompt)
296
+
297
+ text_inputs = self.tokenizer(
298
+ prompt,
299
+ padding="max_length",
300
+ max_length=self.tokenizer_max_length,
301
+ truncation=True,
302
+ return_overflowing_tokens=False,
303
+ return_length=False,
304
+ return_tensors="pt",
305
+ )
306
+
307
+ text_input_ids = text_inputs.input_ids
308
+ untruncated_ids = self.tokenizer(
309
+ prompt, padding="longest", return_tensors="pt"
310
+ ).input_ids
311
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
312
+ text_input_ids, untruncated_ids
313
+ ):
314
+ removed_text = self.tokenizer.batch_decode(
315
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
316
+ )
317
+ logger.warning(
318
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
319
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
320
+ )
321
+ prompt_embeds = self.text_encoder(
322
+ text_input_ids.to(device), output_hidden_states=False
323
+ )
324
+
325
+ # Use pooled output of CLIPTextModel
326
+ prompt_embeds = prompt_embeds.pooler_output
327
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
328
+
329
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
330
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
331
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
332
+
333
+ return prompt_embeds
334
+
335
+ def encode_prompt(
336
+ self,
337
+ prompt: Union[str, List[str]],
338
+ prompt_2: Union[str, List[str]],
339
+ device: Optional[torch.device] = None,
340
+ num_images_per_prompt: int = 1,
341
+ do_classifier_free_guidance: bool = True,
342
+ negative_prompt: Optional[Union[str, List[str]]] = None,
343
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
344
+ prompt_embeds: Optional[torch.FloatTensor] = None,
345
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
346
+ max_sequence_length: int = 512,
347
+ lora_scale: Optional[float] = None,
348
+ ):
349
+ r"""
350
+
351
+ Args:
352
+ prompt (`str` or `List[str]`, *optional*):
353
+ prompt to be encoded
354
+ prompt_2 (`str` or `List[str]`, *optional*):
355
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
356
+ used in all text-encoders
357
+ device: (`torch.device`):
358
+ torch device
359
+ num_images_per_prompt (`int`):
360
+ number of images that should be generated per prompt
361
+ do_classifier_free_guidance (`bool`):
362
+ whether to use classifier-free guidance or not
363
+ negative_prompt (`str` or `List[str]`, *optional*):
364
+ negative prompt to be encoded
365
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
366
+ negative prompt to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is
367
+ used in all text-encoders
368
+ prompt_embeds (`torch.FloatTensor`, *optional*):
369
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
370
+ provided, text embeddings will be generated from `prompt` input argument.
371
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
372
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
373
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
374
+ clip_skip (`int`, *optional*):
375
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
376
+ the output of the pre-final layer will be used for computing the prompt embeddings.
377
+ lora_scale (`float`, *optional*):
378
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
379
+ """
380
+ device = device or self._execution_device
381
+
382
+ # set lora scale so that monkey patched LoRA
383
+ # function of text encoder can correctly access it
384
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
385
+ self._lora_scale = lora_scale
386
+
387
+ # dynamically adjust the LoRA scale
388
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
389
+ scale_lora_layers(self.text_encoder, lora_scale)
390
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
391
+ scale_lora_layers(self.text_encoder_2, lora_scale)
392
+
393
+ prompt = [prompt] if isinstance(prompt, str) else prompt
394
+ if prompt is not None:
395
+ batch_size = len(prompt)
396
+ else:
397
+ batch_size = prompt_embeds.shape[0]
398
+
399
+ if prompt_embeds is None:
400
+ prompt_2 = prompt_2 or prompt
401
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
402
+
403
+ # We only use the pooled prompt output from the CLIPTextModel
404
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
405
+ prompt=prompt,
406
+ device=device,
407
+ num_images_per_prompt=num_images_per_prompt,
408
+ )
409
+ prompt_embeds = self._get_t5_prompt_embeds(
410
+ prompt=prompt_2,
411
+ num_images_per_prompt=num_images_per_prompt,
412
+ max_sequence_length=max_sequence_length,
413
+ device=device,
414
+ )
415
+
416
+ if do_classifier_free_guidance:
417
+ # 处理 negative prompt
418
+ negative_prompt = negative_prompt or ""
419
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
420
+
421
+ negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
422
+ negative_prompt,
423
+ device=device,
424
+ num_images_per_prompt=num_images_per_prompt,
425
+ )
426
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
427
+ negative_prompt_2,
428
+ num_images_per_prompt=num_images_per_prompt,
429
+ max_sequence_length=max_sequence_length,
430
+ device=device,
431
+ )
432
+ else:
433
+ negative_pooled_prompt_embeds = None
434
+ negative_prompt_embeds = None
435
+
436
+ if self.text_encoder is not None:
437
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
438
+ # Retrieve the original scale by scaling back the LoRA layers
439
+ unscale_lora_layers(self.text_encoder, lora_scale)
440
+
441
+ if self.text_encoder_2 is not None:
442
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
443
+ # Retrieve the original scale by scaling back the LoRA layers
444
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
445
+
446
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(
447
+ device=device, dtype=self.text_encoder.dtype
448
+ )
449
+
450
+ return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds,text_ids
451
+
452
+ def check_inputs(
453
+ self,
454
+ prompt,
455
+ prompt_2,
456
+ height,
457
+ width,
458
+ prompt_embeds=None,
459
+ pooled_prompt_embeds=None,
460
+ callback_on_step_end_tensor_inputs=None,
461
+ max_sequence_length=None,
462
+ ):
463
+ if height % 8 != 0 or width % 8 != 0:
464
+ raise ValueError(
465
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
466
+ )
467
+
468
+ if callback_on_step_end_tensor_inputs is not None and not all(
469
+ k in self._callback_tensor_inputs
470
+ for k in callback_on_step_end_tensor_inputs
471
+ ):
472
+ raise ValueError(
473
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
474
+ )
475
+
476
+ if prompt is not None and prompt_embeds is not None:
477
+ raise ValueError(
478
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
479
+ " only forward one of the two."
480
+ )
481
+ elif prompt_2 is not None and prompt_embeds is not None:
482
+ raise ValueError(
483
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
484
+ " only forward one of the two."
485
+ )
486
+ elif prompt is None and prompt_embeds is None:
487
+ raise ValueError(
488
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
489
+ )
490
+ elif prompt is not None and (
491
+ not isinstance(prompt, str) and not isinstance(prompt, list)
492
+ ):
493
+ raise ValueError(
494
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
495
+ )
496
+ elif prompt_2 is not None and (
497
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
498
+ ):
499
+ raise ValueError(
500
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
501
+ )
502
+
503
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
504
+ raise ValueError(
505
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
506
+ )
507
+
508
+ if max_sequence_length is not None and max_sequence_length > 512:
509
+ raise ValueError(
510
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
511
+ )
512
+
513
+ # Copied from diffusers.pipelines.flux.pipeline_flux._prepare_latent_image_ids
514
+ @staticmethod
515
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
516
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
517
+ latent_image_ids[..., 1] = (
518
+ latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
519
+ )
520
+ latent_image_ids[..., 2] = (
521
+ latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
522
+ )
523
+
524
+ (
525
+ latent_image_id_height,
526
+ latent_image_id_width,
527
+ latent_image_id_channels,
528
+ ) = latent_image_ids.shape
529
+
530
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
531
+ latent_image_ids = latent_image_ids.reshape(
532
+ batch_size,
533
+ latent_image_id_height * latent_image_id_width,
534
+ latent_image_id_channels,
535
+ )
536
+
537
+ return latent_image_ids.to(device=device, dtype=dtype)
538
+
539
+ # Copied from diffusers.pipelines.flux.pipeline_flux._pack_latents
540
+ @staticmethod
541
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
542
+ latents = latents.view(
543
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
544
+ )
545
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
546
+ latents = latents.reshape(
547
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
548
+ )
549
+
550
+ return latents
551
+
552
+ # Copied from diffusers.pipelines.flux.pipeline_flux._unpack_latents
553
+ @staticmethod
554
+ def _unpack_latents(latents, height, width, vae_scale_factor):
555
+ batch_size, num_patches, channels = latents.shape
556
+
557
+ height = height // vae_scale_factor
558
+ width = width // vae_scale_factor
559
+
560
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
561
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
562
+
563
+ latents = latents.reshape(
564
+ batch_size, channels // (2 * 2), height * 2, width * 2
565
+ )
566
+
567
+ return latents
568
+
569
+ # Copied from diffusers.pipelines.flux.pipeline_flux.prepare_latents
570
+ def prepare_latents(
571
+ self,
572
+ batch_size,
573
+ num_channels_latents,
574
+ height,
575
+ width,
576
+ dtype,
577
+ device,
578
+ generator,
579
+ latents=None,
580
+ ):
581
+ height = 2 * (int(height) // self.vae_scale_factor)
582
+ width = 2 * (int(width) // self.vae_scale_factor)
583
+
584
+ shape = (batch_size, num_channels_latents, height, width)
585
+
586
+ if latents is not None:
587
+ latent_image_ids = self._prepare_latent_image_ids(
588
+ batch_size, height, width, device, dtype
589
+ )
590
+ return latents.to(device=device, dtype=dtype), latent_image_ids
591
+
592
+ if isinstance(generator, list) and len(generator) != batch_size:
593
+ raise ValueError(
594
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
595
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
596
+ )
597
+
598
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
599
+ latents = self._pack_latents(
600
+ latents, batch_size, num_channels_latents, height, width
601
+ )
602
+
603
+ latent_image_ids = self._prepare_latent_image_ids(
604
+ batch_size, height, width, device, dtype
605
+ )
606
+
607
+ return latents, latent_image_ids
608
+
609
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
610
+ def prepare_image(
611
+ self,
612
+ image,
613
+ width,
614
+ height,
615
+ batch_size,
616
+ num_images_per_prompt,
617
+ device,
618
+ dtype,
619
+ ):
620
+ if isinstance(image, torch.Tensor):
621
+ pass
622
+ else:
623
+ image = self.image_processor.preprocess(image, height=height, width=width)
624
+
625
+ image_batch_size = image.shape[0]
626
+
627
+ if image_batch_size == 1:
628
+ repeat_by = batch_size
629
+ else:
630
+ # image batch size is the same as prompt batch size
631
+ repeat_by = num_images_per_prompt
632
+
633
+ image = image.repeat_interleave(repeat_by, dim=0)
634
+
635
+ image = image.to(device=device, dtype=dtype)
636
+
637
+ return image
638
+
639
+ def prepare_image_with_mask(
640
+ self,
641
+ image,
642
+ mask,
643
+ width,
644
+ height,
645
+ batch_size,
646
+ num_images_per_prompt,
647
+ device,
648
+ dtype,
649
+ do_classifier_free_guidance = False,
650
+ ):
651
+ # Prepare image
652
+ if isinstance(image, torch.Tensor):
653
+ pass
654
+ else:
655
+ image = self.image_processor.preprocess(image, height=height, width=width)
656
+
657
+ image_batch_size = image.shape[0]
658
+ if image_batch_size == 1:
659
+ repeat_by = batch_size
660
+ else:
661
+ # image batch size is the same as prompt batch size
662
+ repeat_by = num_images_per_prompt
663
+ image = image.repeat_interleave(repeat_by, dim=0)
664
+ image = image.to(device=device, dtype=dtype)
665
+
666
+ # Prepare mask
667
+ if isinstance(mask, torch.Tensor):
668
+ pass
669
+ else:
670
+ mask = self.mask_processor.preprocess(mask, height=height, width=width)
671
+ mask = mask.repeat_interleave(repeat_by, dim=0)
672
+ mask = mask.to(device=device, dtype=dtype)
673
+
674
+ # Get masked image
675
+ masked_image = image.clone()
676
+ masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
677
+
678
+ # Encode to latents
679
+ image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
680
+ image_latents = (
681
+ image_latents - self.vae.config.shift_factor
682
+ ) * self.vae.config.scaling_factor
683
+ image_latents = image_latents.to(dtype)
684
+
685
+ mask = torch.nn.functional.interpolate(
686
+ mask, size=(height // self.vae_scale_factor * 2, width // self.vae_scale_factor * 2)
687
+ )
688
+ mask = 1 - mask
689
+
690
+ control_image = torch.cat([image_latents, mask], dim=1)
691
+
692
+ # Pack cond latents
693
+ packed_control_image = self._pack_latents(
694
+ control_image,
695
+ batch_size * num_images_per_prompt,
696
+ control_image.shape[1],
697
+ control_image.shape[2],
698
+ control_image.shape[3],
699
+ )
700
+
701
+ if do_classifier_free_guidance:
702
+ packed_control_image = torch.cat([packed_control_image] * 2)
703
+
704
+ return packed_control_image, height, width
705
+
706
+ @property
707
+ def guidance_scale(self):
708
+ return self._guidance_scale
709
+
710
+ @property
711
+ def joint_attention_kwargs(self):
712
+ return self._joint_attention_kwargs
713
+
714
+ @property
715
+ def num_timesteps(self):
716
+ return self._num_timesteps
717
+
718
+ @property
719
+ def interrupt(self):
720
+ return self._interrupt
721
+
722
+ @torch.no_grad()
723
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
724
+ def __call__(
725
+ self,
726
+ prompt: Union[str, List[str]] = None,
727
+ prompt_2: Optional[Union[str, List[str]]] = None,
728
+ height: Optional[int] = None,
729
+ width: Optional[int] = None,
730
+ num_inference_steps: int = 28,
731
+ timesteps: List[int] = None,
732
+ guidance_scale: float = 7.0,
733
+ true_guidance_scale: float = 3.5 ,
734
+ negative_prompt: Optional[Union[str, List[str]]] = None,
735
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
736
+ control_image: PipelineImageInput = None,
737
+ control_mask: PipelineImageInput = None,
738
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
739
+ num_images_per_prompt: Optional[int] = 1,
740
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
741
+ latents: Optional[torch.FloatTensor] = None,
742
+ prompt_embeds: Optional[torch.FloatTensor] = None,
743
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
744
+ output_type: Optional[str] = "pil",
745
+ return_dict: bool = True,
746
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
747
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
748
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
749
+ max_sequence_length: int = 512,
750
+ ):
751
+ r"""
752
+ Function invoked when calling the pipeline for generation.
753
+
754
+ Args:
755
+ prompt (`str` or `List[str]`, *optional*):
756
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
757
+ instead.
758
+ prompt_2 (`str` or `List[str]`, *optional*):
759
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
760
+ will be used instead
761
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
762
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
763
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
764
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
765
+ num_inference_steps (`int`, *optional*, defaults to 50):
766
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
767
+ expense of slower inference.
768
+ timesteps (`List[int]`, *optional*):
769
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
770
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
771
+ passed will be used. Must be in descending order.
772
+ guidance_scale (`float`, *optional*, defaults to 7.0):
773
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
774
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
775
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
776
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
777
+ usually at the expense of lower image quality.
778
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
779
+ The number of images to generate per prompt.
780
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
781
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
782
+ to make generation deterministic.
783
+ latents (`torch.FloatTensor`, *optional*):
784
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
785
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
786
+ tensor will ge generated by sampling using the supplied random `generator`.
787
+ prompt_embeds (`torch.FloatTensor`, *optional*):
788
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
789
+ provided, text embeddings will be generated from `prompt` input argument.
790
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
791
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
792
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
793
+ output_type (`str`, *optional*, defaults to `"pil"`):
794
+ The output format of the generate image. Choose between
795
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
796
+ return_dict (`bool`, *optional*, defaults to `True`):
797
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
798
+ joint_attention_kwargs (`dict`, *optional*):
799
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
800
+ `self.processor` in
801
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
802
+ callback_on_step_end (`Callable`, *optional*):
803
+ A function that calls at the end of each denoising steps during the inference. The function is called
804
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
805
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
806
+ `callback_on_step_end_tensor_inputs`.
807
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
808
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
809
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
810
+ `._callback_tensor_inputs` attribute of your pipeline class.
811
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
812
+
813
+ Examples:
814
+
815
+ Returns:
816
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
817
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
818
+ images.
819
+ """
820
+
821
+ height = height or self.default_sample_size * self.vae_scale_factor
822
+ width = width or self.default_sample_size * self.vae_scale_factor
823
+
824
+ # 1. Check inputs. Raise error if not correct
825
+ self.check_inputs(
826
+ prompt,
827
+ prompt_2,
828
+ height,
829
+ width,
830
+ prompt_embeds=prompt_embeds,
831
+ pooled_prompt_embeds=pooled_prompt_embeds,
832
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
833
+ max_sequence_length=max_sequence_length,
834
+ )
835
+
836
+ self._guidance_scale = true_guidance_scale
837
+ self._joint_attention_kwargs = joint_attention_kwargs
838
+ self._interrupt = False
839
+
840
+ # 2. Define call parameters
841
+ if prompt is not None and isinstance(prompt, str):
842
+ batch_size = 1
843
+ elif prompt is not None and isinstance(prompt, list):
844
+ batch_size = len(prompt)
845
+ else:
846
+ batch_size = prompt_embeds.shape[0]
847
+
848
+ device = self._execution_device
849
+ dtype = self.transformer.dtype
850
+
851
+ lora_scale = (
852
+ self.joint_attention_kwargs.get("scale", None)
853
+ if self.joint_attention_kwargs is not None
854
+ else None
855
+ )
856
+ (
857
+ prompt_embeds,
858
+ pooled_prompt_embeds,
859
+ negative_prompt_embeds,
860
+ negative_pooled_prompt_embeds,
861
+ text_ids
862
+ ) = self.encode_prompt(
863
+ prompt=prompt,
864
+ prompt_2=prompt_2,
865
+ prompt_embeds=prompt_embeds,
866
+ pooled_prompt_embeds=pooled_prompt_embeds,
867
+ do_classifier_free_guidance = self.do_classifier_free_guidance,
868
+ negative_prompt = negative_prompt,
869
+ negative_prompt_2 = negative_prompt_2,
870
+ device=device,
871
+ num_images_per_prompt=num_images_per_prompt,
872
+ max_sequence_length=max_sequence_length,
873
+ lora_scale=lora_scale,
874
+ )
875
+
876
+ # 在 encode_prompt 之后
877
+ if self.do_classifier_free_guidance:
878
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim = 0)
879
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim = 0)
880
+ text_ids = torch.cat([text_ids, text_ids], dim = 0)
881
+
882
+ # 3. Prepare control image
883
+ num_channels_latents = self.transformer.config.in_channels // 4
884
+ if isinstance(self.controlnet, FluxControlNetModel):
885
+ control_image, height, width = self.prepare_image_with_mask(
886
+ image=control_image,
887
+ mask=control_mask,
888
+ width=width,
889
+ height=height,
890
+ batch_size=batch_size * num_images_per_prompt,
891
+ num_images_per_prompt=num_images_per_prompt,
892
+ device=device,
893
+ dtype=dtype,
894
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
895
+ )
896
+
897
+ # 4. Prepare latent variables
898
+ num_channels_latents = self.transformer.config.in_channels // 4
899
+ latents, latent_image_ids = self.prepare_latents(
900
+ batch_size * num_images_per_prompt,
901
+ num_channels_latents,
902
+ height,
903
+ width,
904
+ prompt_embeds.dtype,
905
+ device,
906
+ generator,
907
+ latents,
908
+ )
909
+
910
+ if self.do_classifier_free_guidance:
911
+ latent_image_ids = torch.cat([latent_image_ids] * 2)
912
+
913
+ # 5. Prepare timesteps
914
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
915
+ image_seq_len = latents.shape[1]
916
+ mu = calculate_shift(
917
+ image_seq_len,
918
+ self.scheduler.config.base_image_seq_len,
919
+ self.scheduler.config.max_image_seq_len,
920
+ self.scheduler.config.base_shift,
921
+ self.scheduler.config.max_shift,
922
+ )
923
+ timesteps, num_inference_steps = retrieve_timesteps(
924
+ self.scheduler,
925
+ num_inference_steps,
926
+ device,
927
+ timesteps,
928
+ sigmas,
929
+ mu=mu,
930
+ )
931
+
932
+ num_warmup_steps = max(
933
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
934
+ )
935
+ self._num_timesteps = len(timesteps)
936
+
937
+ # 6. Denoising loop
938
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
939
+ for i, t in enumerate(timesteps):
940
+ if self.interrupt:
941
+ continue
942
+
943
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
944
+
945
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
946
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
947
+
948
+ # handle guidance
949
+ if self.transformer.config.guidance_embeds:
950
+ guidance = torch.tensor([guidance_scale], device=device)
951
+ guidance = guidance.expand(latent_model_input.shape[0])
952
+ else:
953
+ guidance = None
954
+
955
+ # controlnet
956
+ (
957
+ controlnet_block_samples,
958
+ controlnet_single_block_samples,
959
+ ) = self.controlnet(
960
+ hidden_states=latent_model_input,
961
+ controlnet_cond=control_image,
962
+ conditioning_scale=controlnet_conditioning_scale,
963
+ timestep=timestep / 1000,
964
+ guidance=guidance,
965
+ pooled_projections=pooled_prompt_embeds,
966
+ encoder_hidden_states=prompt_embeds,
967
+ txt_ids=text_ids,
968
+ img_ids=latent_image_ids,
969
+ joint_attention_kwargs=self.joint_attention_kwargs,
970
+ return_dict=False,
971
+ )
972
+
973
+ noise_pred = self.transformer(
974
+ hidden_states=latent_model_input,
975
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
976
+ timestep=timestep / 1000,
977
+ guidance=guidance,
978
+ pooled_projections=pooled_prompt_embeds,
979
+ encoder_hidden_states=prompt_embeds,
980
+ controlnet_block_samples=[
981
+ sample.to(dtype=self.transformer.dtype)
982
+ for sample in controlnet_block_samples
983
+ ],
984
+ controlnet_single_block_samples=[
985
+ sample.to(dtype=self.transformer.dtype)
986
+ for sample in controlnet_single_block_samples
987
+ ] if controlnet_single_block_samples is not None else controlnet_single_block_samples,
988
+ txt_ids=text_ids,
989
+ img_ids=latent_image_ids,
990
+ joint_attention_kwargs=self.joint_attention_kwargs,
991
+ return_dict=False,
992
+ )[0]
993
+
994
+ # 在生成循环中
995
+ if self.do_classifier_free_guidance:
996
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
997
+ noise_pred = noise_pred_uncond + true_guidance_scale * (noise_pred_text - noise_pred_uncond)
998
+
999
+ # compute the previous noisy sample x_t -> x_t-1
1000
+ latents_dtype = latents.dtype
1001
+ latents = self.scheduler.step(
1002
+ noise_pred, t, latents, return_dict=False
1003
+ )[0]
1004
+
1005
+ if latents.dtype != latents_dtype:
1006
+ if torch.backends.mps.is_available():
1007
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1008
+ latents = latents.to(latents_dtype)
1009
+
1010
+ if callback_on_step_end is not None:
1011
+ callback_kwargs = {}
1012
+ for k in callback_on_step_end_tensor_inputs:
1013
+ callback_kwargs[k] = locals()[k]
1014
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1015
+
1016
+ latents = callback_outputs.pop("latents", latents)
1017
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1018
+
1019
+ # call the callback, if provided
1020
+ if i == len(timesteps) - 1 or (
1021
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1022
+ ):
1023
+ progress_bar.update()
1024
+
1025
+ if XLA_AVAILABLE:
1026
+ xm.mark_step()
1027
+
1028
+ if output_type == "latent":
1029
+ image = latents
1030
+
1031
+ else:
1032
+ latents = self._unpack_latents(
1033
+ latents, height, width, self.vae_scale_factor
1034
+ )
1035
+ latents = (
1036
+ latents / self.vae.config.scaling_factor
1037
+ ) + self.vae.config.shift_factor
1038
+ latents = latents.to(self.vae.dtype)
1039
+
1040
+ image = self.vae.decode(latents, return_dict=False)[0]
1041
+ image = self.image_processor.postprocess(image, output_type=output_type)
1042
+
1043
+ # Offload all models
1044
+ self.maybe_free_model_hooks()
1045
+
1046
+ if not return_dict:
1047
+ return (image,)
1048
+
1049
+ return FluxPipelineOutput(images=image)
readme.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="display: flex;align-items: center;">
2
+ <img src="images/alibabaalimama.png" alt="alibaba" style="width: 40%; height: auto; margin: 0 10px;">
3
+ </div>
4
+
5
+ This repository provides a Inpainting ControlNet checkpoint for [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) model released by researchers from AlimamaCreative Team.
6
+
7
+ ## News
8
+
9
+ 🎉 Thanks to @comfyanonymous,ComfyUI now supports inference for Alimama inpainting ControlNet. Workflow can be downloaded from [here](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/alimama-flux-controlnet-inpaint.json).
10
+
11
+ ComfyUI Usage Tips:
12
+
13
+ * Using the `t5xxl-FP16` and `flux1-dev-fp8` models for 28-step inference, the GPU memory usage is 27GB. The inference time with `cfg=3.5` is 27 seconds, while without `cfg=1` it is 15 seconds. `Hyper-FLUX-lora` can be used to accelerate inference.
14
+ * You can try adjusting(lower) the parameters `control-strength`, `control-end-percent`, and `cfg` to achieve better results.
15
+ * The following example uses `control-strength` = 0.9 & `control-end-percent` = 1.0 & `cfg` = 3.5
16
+
17
+ | Input | Output | Prompt |
18
+ |------------------------------|------------------------------|-------------|
19
+ | ![Image1](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_1.png) | ![Image2](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_1.png) | <small><i>The image depicts a scene from the anime series Dragon Ball Z, with the characters Goku, <span style="color:red; font-weight:bold;">Elon Musk</span>, and a child version of Gohan sharing a meal of ramen noodles. They are all sitting around a dining table, with Goku and Gohan on one side and Naruto on the other. They are all holding chopsticks and eating the noodles. The table is set with bowls of ramen, cups, and bowls of drinks. The arrangement of the characters and the food creates a sense of camaraderie and shared enjoyment of the meal. |
20
+ | ![Image3](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_2.png) | ![Image4](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_2.png) | <small><i>The image is an illustration of a man standing in a cafe. He is wearing a white turtleneck, a camel-colored trench coat, and brown shoes. He is holding a cell phone and appears to be looking at it. There is a small table with <span style="color:red; font-weight:bold;">a cat</span> on it to his right. In the background, there is another man sitting at a table with a laptop. The man is wearing a black turtleneck and a tie. </i></small>|
21
+ | ![Image5](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_3.png) | ![Image6](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_3.png) | <small><i>A woman with blonde hair is sitting on a table wearing a <span style="color:red; font-weight:bold;">red and white long dress</span>. She is holding a green phone in her hand and appears to be taking a photo. There is a bag next to her on the table and a handbag beside her on the chair. The woman is looking at the phone with a smile on her face. The background includes a TV on the left wall and a couch on the right. A chair is also present in the scene. </i></small>|
22
+ | ![Image7](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_in_4.png) | ![Image8](https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/comfy_out_4.png) | <small><i>The image depicts a beautiful young woman sitting at a desk, reading a book. She has long, wavy brown hair and is wearing a grey shirt with a black cardigan. She is holding a <span style="color:red; font-weight:bold;">red pencil</span> in her left hand and appears to be deep in thought. Surrounding her are numerous books, some stacked on the desk and others placed on a shelf behind her. A potted plant is also visible in the background, adding a touch of greenery to the scene. The image conveys a sense of serenity and intellectual pursuits. </i></small>|
23
+
24
+
25
+ ## Model Cards
26
+
27
+ <!-- 使用HTML来调整图标大小 -->
28
+ <a href="https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha" target="_blank">
29
+ <img src="https://huggingface.co/favicon.ico" alt="Hugging Face" width="25" height="25" /> The model weights have been uploaded to Hugging Face.
30
+ </a>
31
+
32
+ * The model was trained on 12M laion2B and internal source images at resolution 768x768. The inference performs best at this size, with other sizes yielding suboptimal results.
33
+
34
+ * The recommended controlnet_conditioning_scale is 0.9 - 0.95.
35
+
36
+ * **Please note: This is only the alpha version during the training process. We will release an updated version when we feel ready.**
37
+
38
+ ## Showcase
39
+
40
+ ![flux1](images/flux1.jpg)
41
+ ![flux2](images/flux2.jpg)
42
+ ![flux3](images/flux3.jpg)
43
+
44
+ ## Comparison with SDXL-Inpainting
45
+
46
+ Compared with [SDXL-Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1)
47
+
48
+ From left to right: Input image | Masked image | SDXL inpainting | Ours
49
+
50
+ ![0](images/0.jpg)
51
+ <small><i>*The image depicts a beautiful young woman sitting at a desk, reading a book. She has long, wavy brown hair and is wearing a grey shirt with a black cardigan. She is holding a pencil in her left hand and appears to be deep in thought. Surrounding her are numerous books, some stacked on the desk and others placed on a shelf behind her. A potted plant is also visible in the background, adding a touch of greenery to the scene. The image conveys a sense of serenity and intellectual pursuits.*</i></small>
52
+
53
+ ![0](images/1.jpg)
54
+ <small><i>A woman with blonde hair is sitting on a table wearing a blue and white long dress. She is holding a green phone in her hand and appears to be taking a photo. There is a bag next to her on the table and a handbag beside her on the chair. The woman is looking at the phone with a smile on her face. The background includes a TV on the left wall and a couch on the right. A chair is also present in the scene.</i></small>
55
+
56
+ ![0](images/2.jpg)
57
+ <small><i>The image is an illustration of a man standing in a cafe. He is wearing a white turtleneck, a camel-colored trench coat, and brown shoes. He is holding a cell phone and appears to be looking at it. There is a small table with a cup of coffee on it to his right. In the background, there is another man sitting at a table with a laptop. The man is wearing a black turtleneck and a tie. There are several cups and a cake on the table in the background. The man sitting at the table appears to be typing on the laptop.</i></small>
58
+
59
+ ![0](images/3.jpg)
60
+ <small><i>The image depicts a scene from the anime series Dragon Ball Z, with the characters Goku, Naruto, and a child version of Gohan sharing a meal of ramen noodles. They are all sitting around a dining table, with Goku and Gohan on one side and Naruto on the other. They are all holding chopsticks and eating the noodles. The table is set with bowls of ramen, cups, and bowls of drinks. The arrangement of the characters and the food creates a sense of camaraderie and shared enjoyment of the meal.</i></small>
61
+
62
+ ## Using with Diffusers
63
+ Step1: install diffusers
64
+ ``` Shell
65
+ pip install diffusers==0.30.2
66
+ ```
67
+
68
+ Step2: clone repo from github
69
+ ``` Shell
70
+ git clone https://github.com/alimama-creative/FLUX-Controlnet-Inpainting.git
71
+ ```
72
+
73
+ Step3: modify the image_path, mask_path, prompt and run
74
+ ``` Shell
75
+ python main.py
76
+ ```
77
+ ## LICENSE
78
+ Our weights fall under the [FLUX.1 [dev]](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md) Non-Commercial License.
transformer_flux.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
10
+ from diffusers.models.attention import FeedForward
11
+ from diffusers.models.attention_processor import (
12
+ Attention,
13
+ FluxAttnProcessor2_0,
14
+ FluxSingleAttnProcessor2_0,
15
+ )
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers.models.normalization import (
18
+ AdaLayerNormContinuous,
19
+ AdaLayerNormZero,
20
+ AdaLayerNormZeroSingle,
21
+ )
22
+ from diffusers.utils import (
23
+ USE_PEFT_BACKEND,
24
+ is_torch_version,
25
+ logging,
26
+ scale_lora_layers,
27
+ unscale_lora_layers,
28
+ )
29
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
30
+ from diffusers.models.embeddings import (
31
+ CombinedTimestepGuidanceTextProjEmbeddings,
32
+ CombinedTimestepTextProjEmbeddings,
33
+ )
34
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ # YiYi to-do: refactor rope related functions/classes
41
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
42
+ assert dim % 2 == 0, "The dimension must be even."
43
+
44
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
45
+ omega = 1.0 / (theta**scale)
46
+
47
+ batch_size, seq_length = pos.shape
48
+ out = torch.einsum("...n,d->...nd", pos, omega)
49
+ cos_out = torch.cos(out)
50
+ sin_out = torch.sin(out)
51
+
52
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
53
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
54
+ return out.float()
55
+
56
+
57
+ # YiYi to-do: refactor rope related functions/classes
58
+ class EmbedND(nn.Module):
59
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
60
+ super().__init__()
61
+ self.dim = dim
62
+ self.theta = theta
63
+ self.axes_dim = axes_dim
64
+
65
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
66
+ n_axes = ids.shape[-1]
67
+ emb = torch.cat(
68
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
69
+ dim=-3,
70
+ )
71
+ return emb.unsqueeze(1)
72
+
73
+
74
+ @maybe_allow_in_graph
75
+ class FluxSingleTransformerBlock(nn.Module):
76
+ r"""
77
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
78
+
79
+ Reference: https://arxiv.org/abs/2403.03206
80
+
81
+ Parameters:
82
+ dim (`int`): The number of channels in the input and output.
83
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
84
+ attention_head_dim (`int`): The number of channels in each head.
85
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
86
+ processing of `context` conditions.
87
+ """
88
+
89
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
90
+ super().__init__()
91
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
92
+
93
+ self.norm = AdaLayerNormZeroSingle(dim)
94
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
95
+ self.act_mlp = nn.GELU(approximate="tanh")
96
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
97
+
98
+ processor = FluxSingleAttnProcessor2_0()
99
+ self.attn = Attention(
100
+ query_dim=dim,
101
+ cross_attention_dim=None,
102
+ dim_head=attention_head_dim,
103
+ heads=num_attention_heads,
104
+ out_dim=dim,
105
+ bias=True,
106
+ processor=processor,
107
+ qk_norm="rms_norm",
108
+ eps=1e-6,
109
+ pre_only=True,
110
+ )
111
+
112
+ def forward(
113
+ self,
114
+ hidden_states: torch.FloatTensor,
115
+ temb: torch.FloatTensor,
116
+ image_rotary_emb=None,
117
+ ):
118
+ residual = hidden_states
119
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
120
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
121
+
122
+ attn_output = self.attn(
123
+ hidden_states=norm_hidden_states,
124
+ image_rotary_emb=image_rotary_emb,
125
+ )
126
+
127
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
128
+ gate = gate.unsqueeze(1)
129
+ hidden_states = gate * self.proj_out(hidden_states)
130
+ hidden_states = residual + hidden_states
131
+ if hidden_states.dtype == torch.float16:
132
+ hidden_states = hidden_states.clip(-65504, 65504)
133
+
134
+ return hidden_states
135
+
136
+
137
+ @maybe_allow_in_graph
138
+ class FluxTransformerBlock(nn.Module):
139
+ r"""
140
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
141
+
142
+ Reference: https://arxiv.org/abs/2403.03206
143
+
144
+ Parameters:
145
+ dim (`int`): The number of channels in the input and output.
146
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
147
+ attention_head_dim (`int`): The number of channels in each head.
148
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
149
+ processing of `context` conditions.
150
+ """
151
+
152
+ def __init__(
153
+ self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
154
+ ):
155
+ super().__init__()
156
+
157
+ self.norm1 = AdaLayerNormZero(dim)
158
+
159
+ self.norm1_context = AdaLayerNormZero(dim)
160
+
161
+ if hasattr(F, "scaled_dot_product_attention"):
162
+ processor = FluxAttnProcessor2_0()
163
+ else:
164
+ raise ValueError(
165
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
166
+ )
167
+ self.attn = Attention(
168
+ query_dim=dim,
169
+ cross_attention_dim=None,
170
+ added_kv_proj_dim=dim,
171
+ dim_head=attention_head_dim,
172
+ heads=num_attention_heads,
173
+ out_dim=dim,
174
+ context_pre_only=False,
175
+ bias=True,
176
+ processor=processor,
177
+ qk_norm=qk_norm,
178
+ eps=eps,
179
+ )
180
+
181
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
182
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
183
+
184
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
185
+ self.ff_context = FeedForward(
186
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
187
+ )
188
+
189
+ # let chunk size default to None
190
+ self._chunk_size = None
191
+ self._chunk_dim = 0
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.FloatTensor,
196
+ encoder_hidden_states: torch.FloatTensor,
197
+ temb: torch.FloatTensor,
198
+ image_rotary_emb=None,
199
+ ):
200
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
201
+ hidden_states, emb=temb
202
+ )
203
+
204
+ (
205
+ norm_encoder_hidden_states,
206
+ c_gate_msa,
207
+ c_shift_mlp,
208
+ c_scale_mlp,
209
+ c_gate_mlp,
210
+ ) = self.norm1_context(encoder_hidden_states, emb=temb)
211
+
212
+ # Attention.
213
+ attn_output, context_attn_output = self.attn(
214
+ hidden_states=norm_hidden_states,
215
+ encoder_hidden_states=norm_encoder_hidden_states,
216
+ image_rotary_emb=image_rotary_emb,
217
+ )
218
+
219
+ # Process attention outputs for the `hidden_states`.
220
+ attn_output = gate_msa.unsqueeze(1) * attn_output
221
+ hidden_states = hidden_states + attn_output
222
+
223
+ norm_hidden_states = self.norm2(hidden_states)
224
+ norm_hidden_states = (
225
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
226
+ )
227
+
228
+ ff_output = self.ff(norm_hidden_states)
229
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
230
+
231
+ hidden_states = hidden_states + ff_output
232
+
233
+ # Process attention outputs for the `encoder_hidden_states`.
234
+
235
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
236
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
237
+
238
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
239
+ norm_encoder_hidden_states = (
240
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
241
+ + c_shift_mlp[:, None]
242
+ )
243
+
244
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
245
+ encoder_hidden_states = (
246
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
247
+ )
248
+ if encoder_hidden_states.dtype == torch.float16:
249
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
250
+
251
+ return encoder_hidden_states, hidden_states
252
+
253
+
254
+ class FluxTransformer2DModel(
255
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
256
+ ):
257
+ """
258
+ The Transformer model introduced in Flux.
259
+
260
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
261
+
262
+ Parameters:
263
+ patch_size (`int`): Patch size to turn the input data into small patches.
264
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
265
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
266
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
267
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
268
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
269
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
270
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
271
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
272
+ """
273
+
274
+ _supports_gradient_checkpointing = True
275
+
276
+ @register_to_config
277
+ def __init__(
278
+ self,
279
+ patch_size: int = 1,
280
+ in_channels: int = 64,
281
+ num_layers: int = 19,
282
+ num_single_layers: int = 38,
283
+ attention_head_dim: int = 128,
284
+ num_attention_heads: int = 24,
285
+ joint_attention_dim: int = 4096,
286
+ pooled_projection_dim: int = 768,
287
+ guidance_embeds: bool = False,
288
+ axes_dims_rope: List[int] = [16, 56, 56],
289
+ ):
290
+ super().__init__()
291
+ self.out_channels = in_channels
292
+ self.inner_dim = (
293
+ self.config.num_attention_heads * self.config.attention_head_dim
294
+ )
295
+
296
+ self.pos_embed = EmbedND(
297
+ dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope
298
+ )
299
+ text_time_guidance_cls = (
300
+ CombinedTimestepGuidanceTextProjEmbeddings
301
+ if guidance_embeds
302
+ else CombinedTimestepTextProjEmbeddings
303
+ )
304
+ self.time_text_embed = text_time_guidance_cls(
305
+ embedding_dim=self.inner_dim,
306
+ pooled_projection_dim=self.config.pooled_projection_dim,
307
+ )
308
+
309
+ self.context_embedder = nn.Linear(
310
+ self.config.joint_attention_dim, self.inner_dim
311
+ )
312
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
313
+
314
+ self.transformer_blocks = nn.ModuleList(
315
+ [
316
+ FluxTransformerBlock(
317
+ dim=self.inner_dim,
318
+ num_attention_heads=self.config.num_attention_heads,
319
+ attention_head_dim=self.config.attention_head_dim,
320
+ )
321
+ for i in range(self.config.num_layers)
322
+ ]
323
+ )
324
+
325
+ self.single_transformer_blocks = nn.ModuleList(
326
+ [
327
+ FluxSingleTransformerBlock(
328
+ dim=self.inner_dim,
329
+ num_attention_heads=self.config.num_attention_heads,
330
+ attention_head_dim=self.config.attention_head_dim,
331
+ )
332
+ for i in range(self.config.num_single_layers)
333
+ ]
334
+ )
335
+
336
+ self.norm_out = AdaLayerNormContinuous(
337
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
338
+ )
339
+ self.proj_out = nn.Linear(
340
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
341
+ )
342
+
343
+ self.gradient_checkpointing = False
344
+
345
+ def _set_gradient_checkpointing(self, module, value=False):
346
+ if hasattr(module, "gradient_checkpointing"):
347
+ module.gradient_checkpointing = value
348
+
349
+ def forward(
350
+ self,
351
+ hidden_states: torch.Tensor,
352
+ encoder_hidden_states: torch.Tensor = None,
353
+ pooled_projections: torch.Tensor = None,
354
+ timestep: torch.LongTensor = None,
355
+ img_ids: torch.Tensor = None,
356
+ txt_ids: torch.Tensor = None,
357
+ guidance: torch.Tensor = None,
358
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
359
+ controlnet_block_samples=None,
360
+ controlnet_single_block_samples=None,
361
+ return_dict: bool = True,
362
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
363
+ """
364
+ The [`FluxTransformer2DModel`] forward method.
365
+
366
+ Args:
367
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
368
+ Input `hidden_states`.
369
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
370
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
371
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
372
+ from the embeddings of input conditions.
373
+ timestep ( `torch.LongTensor`):
374
+ Used to indicate denoising step.
375
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
376
+ A list of tensors that if specified are added to the residuals of transformer blocks.
377
+ joint_attention_kwargs (`dict`, *optional*):
378
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
379
+ `self.processor` in
380
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
381
+ return_dict (`bool`, *optional*, defaults to `True`):
382
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
383
+ tuple.
384
+
385
+ Returns:
386
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
387
+ `tuple` where the first element is the sample tensor.
388
+ """
389
+ if joint_attention_kwargs is not None:
390
+ joint_attention_kwargs = joint_attention_kwargs.copy()
391
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
392
+ else:
393
+ lora_scale = 1.0
394
+
395
+ if USE_PEFT_BACKEND:
396
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
397
+ scale_lora_layers(self, lora_scale)
398
+ else:
399
+ if (
400
+ joint_attention_kwargs is not None
401
+ and joint_attention_kwargs.get("scale", None) is not None
402
+ ):
403
+ logger.warning(
404
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
405
+ )
406
+ hidden_states = self.x_embedder(hidden_states)
407
+
408
+ timestep = timestep.to(hidden_states.dtype) * 1000
409
+ if guidance is not None:
410
+ guidance = guidance.to(hidden_states.dtype) * 1000
411
+ else:
412
+ guidance = None
413
+ temb = (
414
+ self.time_text_embed(timestep, pooled_projections)
415
+ if guidance is None
416
+ else self.time_text_embed(timestep, guidance, pooled_projections)
417
+ )
418
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
419
+
420
+ txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
421
+ ids = torch.cat((txt_ids, img_ids), dim=1)
422
+ image_rotary_emb = self.pos_embed(ids)
423
+
424
+ for index_block, block in enumerate(self.transformer_blocks):
425
+ if self.training and self.gradient_checkpointing:
426
+
427
+ def create_custom_forward(module, return_dict=None):
428
+ def custom_forward(*inputs):
429
+ if return_dict is not None:
430
+ return module(*inputs, return_dict=return_dict)
431
+ else:
432
+ return module(*inputs)
433
+
434
+ return custom_forward
435
+
436
+ ckpt_kwargs: Dict[str, Any] = (
437
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
438
+ )
439
+ (
440
+ encoder_hidden_states,
441
+ hidden_states,
442
+ ) = torch.utils.checkpoint.checkpoint(
443
+ create_custom_forward(block),
444
+ hidden_states,
445
+ encoder_hidden_states,
446
+ temb,
447
+ image_rotary_emb,
448
+ **ckpt_kwargs,
449
+ )
450
+
451
+ else:
452
+ encoder_hidden_states, hidden_states = block(
453
+ hidden_states=hidden_states,
454
+ encoder_hidden_states=encoder_hidden_states,
455
+ temb=temb,
456
+ image_rotary_emb=image_rotary_emb,
457
+ )
458
+
459
+ # controlnet residual
460
+ if controlnet_block_samples is not None:
461
+ interval_control = len(self.transformer_blocks) / len(
462
+ controlnet_block_samples
463
+ )
464
+ interval_control = int(np.ceil(interval_control))
465
+ hidden_states = (
466
+ hidden_states
467
+ + controlnet_block_samples[index_block // interval_control]
468
+ )
469
+
470
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
471
+
472
+ for index_block, block in enumerate(self.single_transformer_blocks):
473
+ if self.training and self.gradient_checkpointing:
474
+
475
+ def create_custom_forward(module, return_dict=None):
476
+ def custom_forward(*inputs):
477
+ if return_dict is not None:
478
+ return module(*inputs, return_dict=return_dict)
479
+ else:
480
+ return module(*inputs)
481
+
482
+ return custom_forward
483
+
484
+ ckpt_kwargs: Dict[str, Any] = (
485
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
486
+ )
487
+ hidden_states = torch.utils.checkpoint.checkpoint(
488
+ create_custom_forward(block),
489
+ hidden_states,
490
+ temb,
491
+ image_rotary_emb,
492
+ **ckpt_kwargs,
493
+ )
494
+
495
+ else:
496
+ hidden_states = block(
497
+ hidden_states=hidden_states,
498
+ temb=temb,
499
+ image_rotary_emb=image_rotary_emb,
500
+ )
501
+
502
+ # controlnet residual
503
+ if controlnet_single_block_samples is not None:
504
+ interval_control = len(self.single_transformer_blocks) / len(
505
+ controlnet_single_block_samples
506
+ )
507
+ interval_control = int(np.ceil(interval_control))
508
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
509
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
510
+ + controlnet_single_block_samples[index_block // interval_control]
511
+ )
512
+
513
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
514
+
515
+ hidden_states = self.norm_out(hidden_states, temb)
516
+ output = self.proj_out(hidden_states)
517
+
518
+ if USE_PEFT_BACKEND:
519
+ # remove `lora_scale` from each PEFT layer
520
+ unscale_lora_layers(self, lora_scale)
521
+
522
+ if not return_dict:
523
+ return (output,)
524
+
525
+ return Transformer2DModelOutput(sample=output)