instantx-admin committed on
Commit
70179d9
1 Parent(s): 5df0af7

Upload 5 files

attention_processor.py ADDED
@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.normalization import FP32LayerNorm, RMSNorm
from typing import Callable, List, Optional, Tuple, Union
import math

import numpy as np
from PIL import Image


class IPAFluxAttnProcessor2_0(nn.Module):
    """Attention processor typically used for SD3-like self-attention projections, extended here with IP-Adapter image projections."""

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size  # 3072
        self.cross_attention_dim = cross_attention_dim  # 4096
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)
        # self.norm_added_v = RMSNorm(128, eps=1e-5, elementwise_affine=False)

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        image_emb: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)  # torch.Size([1, 24, 4800, 128])
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_emb is not None:
            # `ip-adapter` projections
            ip_hidden_states = image_emb
            ip_hidden_states_key_proj = self.to_k_ip(ip_hidden_states)
            ip_hidden_states_value_proj = self.to_v_ip(ip_hidden_states)

            ip_hidden_states_key_proj = ip_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            ip_hidden_states_value_proj = ip_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            ip_hidden_states_key_proj = self.norm_added_k(ip_hidden_states_key_proj)
            # ip_hidden_states_value_proj = self.norm_added_v(ip_hidden_states_value_proj)

            ip_hidden_states = F.scaled_dot_product_attention(
                query,
                ip_hidden_states_key_proj,
                ip_hidden_states_value_proj,
                dropout_p=0.0, is_causal=False,
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:

            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)  # (512+3840,128)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:

            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )
            if image_emb is not None:
                hidden_states = hidden_states + self.scale * ip_hidden_states

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            if image_emb is not None:
                hidden_states = hidden_states + self.scale * ip_hidden_states

            return hidden_states
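Not part of the committed file: a minimal sketch of instantiating the processor on its own to check the shapes of its trainable image projections. The sizes assume FLUX.1-dev defaults (hidden_size = 24 heads x 128 = 3072, joint_attention_dim = 4096) and num_tokens=128 as used in the inference script below; adjust for other checkpoints.

import torch
from attention_processor import IPAFluxAttnProcessor2_0

proc = IPAFluxAttnProcessor2_0(hidden_size=3072, cross_attention_dim=4096, num_tokens=128)

image_emb = torch.randn(1, 128, 4096)   # stand-in for projected image tokens
ip_key = proc.to_k_ip(image_emb)        # (1, 128, 3072); reshaped into 24 heads of dim 128 inside __call__
ip_value = proc.to_v_ip(image_emb)      # (1, 128, 3072)
print(ip_key.shape, ip_value.shape)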
infer_flux_ipa_siglip.py ADDED
@@ -0,0 +1,190 @@
import os
import glob
import numpy as np
from PIL import Image

import torch
import torch.nn as nn

from pipeline_flux_ipa import FluxPipeline
from transformer_flux import FluxTransformer2DModel
from attention_processor import IPAFluxAttnProcessor2_0
from transformers import AutoProcessor, SiglipVisionModel

def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):

    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio*w), round(ratio*h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image

class MLPProjModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x

class IPAdapter:
    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens

        self.pipe = sd_pipe.to(self.device)
        self.set_ip_adapter()

        # load image encoder
        self.image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path).to(self.device, dtype=torch.bfloat16)
        self.clip_image_processor = AutoProcessor.from_pretrained(self.image_encoder_path)

        # image proj model
        self.image_proj_model = self.init_proj()

        self.load_ip_adapter()

    def init_proj(self):
        image_proj_model = MLPProjModel(
            cross_attention_dim=self.pipe.transformer.config.joint_attention_dim,  # 4096
            id_embeddings_dim=1152,
            num_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.bfloat16)

        return image_proj_model

    def set_ip_adapter(self):
        transformer = self.pipe.transformer
        ip_attn_procs = {}  # 19+38=57
        for name in transformer.attn_processors.keys():
            if name.startswith("transformer_blocks.") or name.startswith("single_transformer_blocks"):
                ip_attn_procs[name] = IPAFluxAttnProcessor2_0(
                    hidden_size=transformer.config.num_attention_heads * transformer.config.attention_head_dim,
                    cross_attention_dim=transformer.config.joint_attention_dim,
                    num_tokens=self.num_tokens,
                ).to(self.device, dtype=torch.bfloat16)
            else:
                ip_attn_procs[name] = transformer.attn_processors[name]

        transformer.set_attn_processor(ip_attn_procs)

    def load_ip_adapter(self):
        state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        ip_layers = torch.nn.ModuleList(self.pipe.transformer.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=self.image_encoder.dtype)).pooler_output
            clip_image_embeds = clip_image_embeds.to(dtype=torch.bfloat16)
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.bfloat16)
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        return image_prompt_embeds

    def set_scale(self, scale):
        for attn_processor in self.pipe.transformer.attn_processors.values():
            if isinstance(attn_processor, IPAFluxAttnProcessor2_0):
                attn_processor.scale = scale

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        scale=1.0,
        num_samples=1,
        seed=None,
        guidance_scale=3.5,
        num_inference_steps=24,
        **kwargs,
    ):
        self.set_scale(scale)

        image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )

        if seed is None:
            generator = None
        else:
            generator = torch.Generator(self.device).manual_seed(seed)

        images = self.pipe(
            prompt=prompt,
            image_emb=image_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images


if __name__ == '__main__':

    model_path = "black-forest-labs/FLUX.1-dev"
    image_encoder_path = "google/siglip-so400m-patch14-384"
    ipadapter_path = "./ip-adapter.bin"

    transformer = FluxTransformer2DModel.from_pretrained(
        model_path, subfolder="transformer", torch_dtype=torch.bfloat16
    )

    pipe = FluxPipeline.from_pretrained(
        model_path, transformer=transformer, torch_dtype=torch.bfloat16
    )

    ip_model = IPAdapter(pipe, image_encoder_path, ipadapter_path, device="cuda", num_tokens=128)

    image_dir = "./assets/images/2.jpg"
    image_name = image_dir.split("/")[-1]
    image = Image.open(image_dir).convert("RGB")
    image = resize_img(image)

    prompt = "a young girl"

    images = ip_model.generate(
        pil_image=image,
        prompt=prompt,
        scale=0.7,
        width=960, height=1280,
        seed=42
    )

    images[0].save(f"results/{image_name}")
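A possible batch variant, not included in the commit: the script imports `glob` but only processes a single reference image and assumes the `results/` directory exists. The sketch below reuses `resize_img` and the `ip_model` built above and loops over a folder of reference images; the folder path is illustrative.

import os
import glob
from PIL import Image

os.makedirs("results", exist_ok=True)  # create the output directory the script writes into

for image_path in sorted(glob.glob("./assets/images/*.jpg")):
    image = resize_img(Image.open(image_path).convert("RGB"))
    images = ip_model.generate(
        pil_image=image,
        prompt="a young girl",
        scale=0.7,
        width=960, height=1280,
        seed=42,
    )
    images[0].save(os.path.join("results", os.path.basename(image_path)))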
ip-adapter.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:01aaeb0dace9ea4feee1007d30e41c59847041c116db8c2bcc14f9c91d491203
size 5291249450
pipeline_flux_ipa.py ADDED
@@ -0,0 +1,874 @@
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import FluxTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import FluxPipeline

        >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")
        >>> prompt = "A cat holding a sign that says hello world"
        >>> # Depending on the variant being used, the pipeline call will slightly vary.
        >>> # Refer to the pipeline documentation for more details.
        >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
        >>> image.save("flux.png")
        ```
"""


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
    r"""
    The Flux pipeline for text-to-image generation.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
        transformer ([`FluxTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_encoder_2 ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer_2 (`T5TokenizerFast`):
            Second Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
    _optional_components = []
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        )
        self.default_sample_size = 64

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def _get_clip_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    def encode_regional_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        # hard code here!
        regional_prompts = prompt[0].split(";")
        prompt_embeds_list = []
        for regional_prompt in regional_prompts:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=regional_prompt,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )
            prompt_embeds_list.append(prompt_embeds)
        prompt_embeds = torch.concat(prompt_embeds_list, dim=1)

        # print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)
        # torch.Size([1, 512*num_prompt, 4096]) torch.Size([1, 768]) torch.Size([512, 3])

        return prompt_embeds, pooled_prompt_embeds, text_ids

    def check_inputs(
        self,
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=None,
        pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt_2 is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        # print(batch_size, height, width)
        # 1 96 160
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(
            latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        batch_size, num_patches, channels = latents.shape

        height = height // vae_scale_factor
        width = width // vae_scale_factor

        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)

        return latents

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        height = 2 * (int(height) // self.vae_scale_factor)
        width = 2 * (int(width) // self.vae_scale_factor)

        shape = (batch_size, num_channels_latents, height, width)

        if latents is not None:
            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)  # torch.Size([1, 16, 96, 160])
        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)  # torch.Size([1, 3840, 64])

        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)  # torch.Size([3840, 3])

        return latents, latent_image_ids

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        image_emb: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                will be used instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    image_emb=image_emb,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
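Hedged usage sketch, not part of the committed file: unlike the upstream diffusers FluxPipeline, this variant accepts an `image_emb` argument and forwards it to the transformer. Assuming `pipe` is this pipeline with the IP-Adapter attention processors installed and `image_prompt_embeds` has already been produced by `MLPProjModel` in infer_flux_ipa_siglip.py (shape `(batch, num_tokens, 4096)`), the call looks like:

images = pipe(
    prompt="a young girl",
    image_emb=image_prompt_embeds,   # precomputed image prompt tokens
    guidance_scale=3.5,
    num_inference_steps=24,
    width=960, height=1280,
).images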
transformer_flux.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.models.attention import FeedForward
26
+ from diffusers.models.attention_processor import (
27
+ Attention,
28
+ AttentionProcessor,
29
+ FluxAttnProcessor2_0,
30
+ FusedFluxAttnProcessor2_0,
31
+ )
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
34
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
37
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ @maybe_allow_in_graph
43
+ class FluxSingleTransformerBlock(nn.Module):
44
+ r"""
45
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
46
+
47
+ Reference: https://arxiv.org/abs/2403.03206
48
+
49
+ Parameters:
50
+ dim (`int`): The number of channels in the input and output.
51
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
52
+ attention_head_dim (`int`): The number of channels in each head.
53
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
54
+ processing of `context` conditions.
55
+ """
56
+
57
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
58
+ super().__init__()
59
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
60
+
61
+ self.norm = AdaLayerNormZeroSingle(dim)
62
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
63
+ self.act_mlp = nn.GELU(approximate="tanh")
64
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
65
+
66
+ processor = FluxAttnProcessor2_0()
67
+ self.attn = Attention(
68
+ query_dim=dim,
69
+ cross_attention_dim=None,
70
+ dim_head=attention_head_dim,
71
+ heads=num_attention_heads,
72
+ out_dim=dim,
73
+ bias=True,
74
+ processor=processor,
75
+ qk_norm="rms_norm",
76
+ eps=1e-6,
77
+ pre_only=True,
78
+ )
79
+
80
+ def forward(
81
+ self,
82
+ hidden_states: torch.FloatTensor,
83
+ temb: torch.FloatTensor,
84
+ image_emb=None,
85
+ image_rotary_emb=None,
86
+ ):
87
+ residual = hidden_states
88
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
89
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
90
+
91
+ attn_output = self.attn(
92
+ hidden_states=norm_hidden_states,
93
+ image_rotary_emb=image_rotary_emb,
94
+ image_emb=image_emb,
95
+ )
96
+
97
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
98
+ gate = gate.unsqueeze(1) # torch.Size([1, 1, 3072])
99
+ hidden_states = gate * self.proj_out(hidden_states) # torch.Size([1, 4352, 3072])
100
+
101
+ hidden_states = residual + hidden_states
102
+ if hidden_states.dtype == torch.float16:
103
+ hidden_states = hidden_states.clip(-65504, 65504)
104
+
105
+ return hidden_states
106
+
107
+
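A minimal usage sketch for the single-stream block defined above, assuming the Flux-scale dimensions of the default model config further down (24 heads x 128 channels per head = 3072); the batch size and sequence length are arbitrary. With the stock FluxAttnProcessor2_0 installed, the `image_emb` argument should simply be ignored (recent diffusers versions drop attention kwargs the active processor does not accept) until an image-conditioned processor is set on `block.attn`.

import torch

block = FluxSingleTransformerBlock(dim=3072, num_attention_heads=24, attention_head_dim=128)
hidden_states = torch.randn(1, 32, 3072)   # concatenated text + image tokens
temb = torch.randn(1, 3072)                # pooled timestep/text conditioning
out = block(hidden_states, temb)           # image_emb / image_rotary_emb default to None
assert out.shape == hidden_states.shape    # residual block: shape is preserved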
108
+ @maybe_allow_in_graph
109
+ class FluxTransformerBlock(nn.Module):
110
+ r"""
111
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
112
+
113
+ Reference: https://arxiv.org/abs/2403.03206
114
+
115
+ Parameters:
116
+ dim (`int`): The number of channels in the input and output.
117
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
118
+ attention_head_dim (`int`): The number of channels in each head.
119
+ qk_norm (`str`, *optional*, defaults to `"rms_norm"`): The normalization applied to the query and key projections.
120
+ eps (`float`, *optional*, defaults to 1e-6): The epsilon used by the query/key normalization layers.
121
+ """
122
+
123
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
124
+ super().__init__()
125
+
126
+ self.norm1 = AdaLayerNormZero(dim)
127
+
128
+ self.norm1_context = AdaLayerNormZero(dim)
129
+
130
+ if hasattr(F, "scaled_dot_product_attention"):
131
+ processor = FluxAttnProcessor2_0()
132
+ else:
133
+ raise ValueError(
134
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
135
+ )
136
+ self.attn = Attention(
137
+ query_dim=dim,
138
+ cross_attention_dim=None,
139
+ added_kv_proj_dim=dim,
140
+ dim_head=attention_head_dim,
141
+ heads=num_attention_heads,
142
+ out_dim=dim,
143
+ context_pre_only=False,
144
+ bias=True,
145
+ processor=processor,
146
+ qk_norm=qk_norm,
147
+ eps=eps,
148
+ )
149
+
150
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
151
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
152
+
153
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
154
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
155
+
156
+ # let chunk size default to None
157
+ self._chunk_size = None
158
+ self._chunk_dim = 0
159
+
160
+ def forward(
161
+ self,
162
+ hidden_states: torch.FloatTensor,
163
+ encoder_hidden_states: torch.FloatTensor,
164
+ temb: torch.FloatTensor,
165
+ image_emb=None,
166
+ image_rotary_emb=None,
167
+ ):
168
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
169
+
170
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
171
+ encoder_hidden_states, emb=temb
172
+ )
173
+
174
+ # Attention.
175
+ attn_output, context_attn_output = self.attn(
176
+ hidden_states=norm_hidden_states,
177
+ encoder_hidden_states=norm_encoder_hidden_states,
178
+ image_rotary_emb=image_rotary_emb,
179
+ image_emb=image_emb,
180
+ )
181
+
182
+ # Process attention outputs for the `hidden_states`.
183
+ attn_output = gate_msa.unsqueeze(1) * attn_output
184
+ hidden_states = hidden_states + attn_output
185
+
186
+ norm_hidden_states = self.norm2(hidden_states)
187
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
188
+
189
+ ff_output = self.ff(norm_hidden_states)
190
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
191
+ hidden_states = hidden_states + ff_output
192
+
193
+ # Process attention outputs for the `encoder_hidden_states`.
194
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
195
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
196
+
197
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
198
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
199
+
200
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
201
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
202
+ if encoder_hidden_states.dtype == torch.float16:
203
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
204
+
205
+ return encoder_hidden_states, hidden_states
206
+
207
+
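A similar sketch for the dual-stream (MMDiT) block above: it takes separate image and text streams plus the conditioning embedding and returns the updated `(encoder_hidden_states, hidden_states)` pair, in that order. The token counts below are illustrative, not the lengths a real pipeline would use.

import torch

block = FluxTransformerBlock(dim=3072, num_attention_heads=24, attention_head_dim=128)
image_tokens = torch.randn(1, 64, 3072)    # packed latent tokens after x_embedder
text_tokens = torch.randn(1, 16, 3072)     # T5 tokens after context_embedder
temb = torch.randn(1, 3072)
text_tokens, image_tokens = block(
    hidden_states=image_tokens,
    encoder_hidden_states=text_tokens,
    temb=temb,                             # image_emb / image_rotary_emb left as None
)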
208
+ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
209
+ """
210
+ The Transformer model introduced in Flux.
211
+
212
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
213
+
214
+ Parameters:
215
+ patch_size (`int`): Patch size to turn the input data into small patches.
216
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
217
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
218
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
219
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
220
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
221
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
222
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
223
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
224
+ """
225
+
226
+ _supports_gradient_checkpointing = True
227
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
228
+
229
+ @register_to_config
230
+ def __init__(
231
+ self,
232
+ patch_size: int = 1,
233
+ in_channels: int = 64,
234
+ num_layers: int = 19,
235
+ num_single_layers: int = 38,
236
+ attention_head_dim: int = 128,
237
+ num_attention_heads: int = 24,
238
+ joint_attention_dim: int = 4096,
239
+ pooled_projection_dim: int = 768,
240
+ guidance_embeds: bool = False,
241
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
242
+ ):
243
+ super().__init__()
244
+ self.out_channels = in_channels
245
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
246
+
247
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
248
+
249
+ text_time_guidance_cls = (
250
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
251
+ )
252
+ self.time_text_embed = text_time_guidance_cls(
253
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
254
+ )
255
+
256
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
257
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
258
+
259
+ self.transformer_blocks = nn.ModuleList(
260
+ [
261
+ FluxTransformerBlock(
262
+ dim=self.inner_dim,
263
+ num_attention_heads=self.config.num_attention_heads,
264
+ attention_head_dim=self.config.attention_head_dim,
265
+ )
266
+ for i in range(self.config.num_layers)
267
+ ]
268
+ )
269
+
270
+ self.single_transformer_blocks = nn.ModuleList(
271
+ [
272
+ FluxSingleTransformerBlock(
273
+ dim=self.inner_dim,
274
+ num_attention_heads=self.config.num_attention_heads,
275
+ attention_head_dim=self.config.attention_head_dim,
276
+ )
277
+ for i in range(self.config.num_single_layers)
278
+ ]
279
+ )
280
+
281
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
282
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
283
+
284
+ self.gradient_checkpointing = False
285
+
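An instantiation sketch for the model defined here. The default configuration (19 dual-stream plus 38 single-stream blocks at inner_dim = 24 * 128 = 3072) is the full-size Flux transformer, so the depth is reduced below purely to keep the example light; all other arguments fall back to their defaults.

model = FluxTransformer2DModel(num_layers=2, num_single_layers=4)
print(model.inner_dim)                       # 3072
print(len(model.transformer_blocks))         # 2
print(len(model.single_transformer_blocks))  # 4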
286
+ @property
287
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
288
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
289
+ r"""
290
+ Returns:
291
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
292
+ indexed by their weight names.
293
+ """
294
+ # set recursively
295
+ processors = {}
296
+
297
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
298
+ if hasattr(module, "get_processor"):
299
+ processors[f"{name}.processor"] = module.get_processor()
300
+
301
+ for sub_name, child in module.named_children():
302
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
303
+
304
+ return processors
305
+
306
+ for name, module in self.named_children():
307
+ fn_recursive_add_processors(name, module, processors)
308
+
309
+ return processors
310
+
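A short sketch of the property above: it walks every submodule exposing `get_processor` and returns a flat dict keyed by module path. Printing it is a quick way to confirm which processor class each attention layer currently uses; the key names shown in the comment are what the recursion produces for this model.

model = FluxTransformer2DModel(num_layers=1, num_single_layers=1)
for name, processor in model.attn_processors.items():
    print(name, type(processor).__name__)    # e.g. transformer_blocks.0.attn.processor FluxAttnProcessor2_0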
311
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
312
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
313
+ r"""
314
+ Sets the attention processor to use to compute attention.
315
+
316
+ Parameters:
317
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
318
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
319
+ for **all** `Attention` layers.
320
+
321
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
322
+ processor. This is strongly recommended when setting trainable attention processors.
323
+
324
+ """
325
+ count = len(self.attn_processors.keys())
326
+
327
+ if isinstance(processor, dict) and len(processor) != count:
328
+ raise ValueError(
329
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
330
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
331
+ )
332
+
333
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
334
+ if hasattr(module, "set_processor"):
335
+ if not isinstance(processor, dict):
336
+ module.set_processor(processor)
337
+ else:
338
+ module.set_processor(processor.pop(f"{name}.processor"))
339
+
340
+ for sub_name, child in module.named_children():
341
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
342
+
343
+ for name, module in self.named_children():
344
+ fn_recursive_attn_processor(name, module, processor)
345
+
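A sketch of per-layer processor swapping with a dict keyed exactly like `attn_processors`; this is the hook through which image-conditioned (IP-Adapter style) processors would be installed so that `image_emb` takes effect. Here the stock processor is reused just to show the plumbing.

model = FluxTransformer2DModel(num_layers=1, num_single_layers=1)
new_processors = {name: FluxAttnProcessor2_0() for name in model.attn_processors}
model.set_attn_processor(new_processors)     # a single processor instance would also be accepted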
346
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
347
+ def fuse_qkv_projections(self):
348
+ """
349
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
350
+ are fused. For cross-attention modules, key and value projection matrices are fused.
351
+
352
+ <Tip warning={true}>
353
+
354
+ This API is 🧪 experimental.
355
+
356
+ </Tip>
357
+ """
358
+ self.original_attn_processors = None
359
+
360
+ for _, attn_processor in self.attn_processors.items():
361
+ if "Added" in str(attn_processor.__class__.__name__):
362
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
363
+
364
+ self.original_attn_processors = self.attn_processors
365
+
366
+ for module in self.modules():
367
+ if isinstance(module, Attention):
368
+ module.fuse_projections(fuse=True)
369
+
370
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
371
+
372
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
373
+ def unfuse_qkv_projections(self):
374
+ """Disables the fused QKV projection if enabled.
375
+
376
+ <Tip warning={true}>
377
+
378
+ This API is 🧪 experimental.
379
+
380
+ </Tip>
381
+
382
+ """
383
+ if self.original_attn_processors is not None:
384
+ self.set_attn_processor(self.original_attn_processors)
385
+
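A sketch of the experimental fused-QKV round trip defined above: fusing merges the query/key/value projections of every `Attention` module and installs `FusedFluxAttnProcessor2_0`, while unfusing restores the processors saved in `original_attn_processors`.

model = FluxTransformer2DModel(num_layers=1, num_single_layers=1)
model.fuse_qkv_projections()
# ... run inference with fused projections ...
model.unfuse_qkv_projections()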
386
+ def _set_gradient_checkpointing(self, module, value=False):
387
+ if hasattr(module, "gradient_checkpointing"):
388
+ module.gradient_checkpointing = value
389
+
390
+ def forward(
391
+ self,
392
+ hidden_states: torch.Tensor,
393
+ encoder_hidden_states: torch.Tensor = None,
394
+ image_emb: torch.FloatTensor = None,
395
+ pooled_projections: torch.Tensor = None,
396
+ timestep: torch.LongTensor = None,
397
+ img_ids: torch.Tensor = None,
398
+ txt_ids: torch.Tensor = None,
399
+ guidance: torch.Tensor = None,
400
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
401
+ controlnet_block_samples=None,
402
+ controlnet_single_block_samples=None,
403
+ return_dict: bool = True,
404
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
405
+ """
406
+ The [`FluxTransformer2DModel`] forward method.
407
+
408
+ Args:
409
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, in_channels)`):
410
+ Input `hidden_states` (packed 2x2 latent patches flattened into a token sequence).
411
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
412
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
413
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
414
+ from the embeddings of input conditions.
415
+ timestep ( `torch.LongTensor`):
416
+ Used to indicate denoising step.
417
+ controlnet_block_samples / controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
418
+ ControlNet residuals that, if specified, are added to the outputs of the dual-stream and single-stream transformer blocks.
419
+ joint_attention_kwargs (`dict`, *optional*):
420
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
421
+ `self.processor` in
422
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
423
+ return_dict (`bool`, *optional*, defaults to `True`):
424
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
425
+ tuple.
426
+
427
+ Returns:
428
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
429
+ `tuple` where the first element is the sample tensor.
430
+ """
431
+ if joint_attention_kwargs is not None:
432
+ joint_attention_kwargs = joint_attention_kwargs.copy()
433
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
434
+ else:
435
+ lora_scale = 1.0
436
+
437
+ if USE_PEFT_BACKEND:
438
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
439
+ scale_lora_layers(self, lora_scale)
440
+ else:
441
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
442
+ logger.warning(
443
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
444
+ )
445
+ hidden_states = self.x_embedder(hidden_states)
446
+
447
+ timestep = timestep.to(hidden_states.dtype) * 1000
448
+ if guidance is not None:
449
+ guidance = guidance.to(hidden_states.dtype) * 1000
450
+ else:
451
+ guidance = None
452
+ temb = (
453
+ self.time_text_embed(timestep, pooled_projections)
454
+ if guidance is None
455
+ else self.time_text_embed(timestep, guidance, pooled_projections)
456
+ )
457
+ # torch.Size([1, 512*num_prompt, 4096]) -> torch.Size([1, 512*num_prompt, 3072])
458
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
459
+
460
+ if txt_ids.ndim == 3:
461
+ logger.warning(
462
+ "Passing `txt_ids` 3d torch.Tensor is deprecated. "
463
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
464
+ )
465
+ txt_ids = txt_ids[0]
466
+ if img_ids.ndim == 3:
467
+ logger.warning(
468
+ "Passing `img_ids` 3d torch.Tensor is deprecated. "
469
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
470
+ )
471
+ img_ids = img_ids[0]
472
+
473
+ ids = torch.cat((txt_ids, img_ids), dim=0)
474
+ image_rotary_emb = self.pos_embed(ids)
475
+
476
+ for index_block, block in enumerate(self.transformer_blocks):
477
+ if self.training and self.gradient_checkpointing:
478
+
479
+ def create_custom_forward(module, return_dict=None):
480
+ def custom_forward(*inputs):
481
+ if return_dict is not None:
482
+ return module(*inputs, return_dict=return_dict)
483
+ else:
484
+ return module(*inputs)
485
+
486
+ return custom_forward
487
+
488
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
489
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
490
+ create_custom_forward(block),
491
+ hidden_states,
492
+ encoder_hidden_states,
493
+ temb,
494
+ image_emb,
495
+ image_rotary_emb,
496
+ **ckpt_kwargs,
497
+ )
498
+
499
+ else:
500
+ encoder_hidden_states, hidden_states = block(
501
+ hidden_states=hidden_states,
502
+ encoder_hidden_states=encoder_hidden_states,
503
+ temb=temb,
504
+ image_emb=image_emb,
505
+ image_rotary_emb=image_rotary_emb,
506
+ )
507
+
508
+ # controlnet residual
509
+ if controlnet_block_samples is not None:
510
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
511
+ interval_control = int(np.ceil(interval_control))
512
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
513
+
514
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
515
+
516
+ for index_block, block in enumerate(self.single_transformer_blocks):
517
+ if self.training and self.gradient_checkpointing:
518
+
519
+ def create_custom_forward(module, return_dict=None):
520
+ def custom_forward(*inputs):
521
+ if return_dict is not None:
522
+ return module(*inputs, return_dict=return_dict)
523
+ else:
524
+ return module(*inputs)
525
+
526
+ return custom_forward
527
+
528
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
529
+ hidden_states = torch.utils.checkpoint.checkpoint(
530
+ create_custom_forward(block),
531
+ hidden_states,
532
+ temb,
533
+ image_emb,
534
+ image_rotary_emb,
535
+ **ckpt_kwargs,
536
+ )
537
+
538
+ else:
539
+ hidden_states = block(
540
+ hidden_states=hidden_states,
541
+ temb=temb,
542
+ image_emb=image_emb,
543
+ image_rotary_emb=image_rotary_emb,
544
+ )
545
+
546
+ # controlnet residual
547
+ if controlnet_single_block_samples is not None:
548
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
549
+ interval_control = int(np.ceil(interval_control))
550
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
551
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
552
+ + controlnet_single_block_samples[index_block // interval_control]
553
+ )
554
+
555
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
556
+
557
+ hidden_states = self.norm_out(hidden_states, temb)
558
+ output = self.proj_out(hidden_states)
559
+
560
+ if USE_PEFT_BACKEND:
561
+ # remove `lora_scale` from each PEFT layer
562
+ unscale_lora_layers(self, lora_scale)
563
+
564
+ if not return_dict:
565
+ return (output,)
566
+
567
+ return Transformer2DModelOutput(sample=output)
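
An end-to-end forward sketch. Flux packs 2x2 latent patches into a token sequence, so `hidden_states` has 64 channels per token rather than a spatial layout; the tiny sequence lengths and all-zero position ids below are only meant to exercise the shapes, and `image_emb` stays `None` until image-conditioned processors are installed via `set_attn_processor`.

import torch

model = FluxTransformer2DModel(num_layers=1, num_single_layers=1)
bsz, img_seq, txt_seq = 1, 64, 16
hidden_states = torch.randn(bsz, img_seq, 64)            # packed latent tokens
encoder_hidden_states = torch.randn(bsz, txt_seq, 4096)  # T5 features before context_embedder
pooled_projections = torch.randn(bsz, 768)                # pooled CLIP text embedding
timestep = torch.tensor([1.0])                            # scaled by 1000 inside forward
img_ids = torch.zeros(img_seq, 3)                         # 2d position ids, as expected
txt_ids = torch.zeros(txt_seq, 3)

out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
    image_emb=None,
)
print(out.sample.shape)                                   # torch.Size([1, 64, 64])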