EVA787797 committed on
Commit bbcc985
1 Parent(s): 2738295

Upload 5 files

README (20).md ADDED
@@ -0,0 +1,64 @@
+ ---
+ license: other
+ license_name: flux-1-dev-non-commercial-license
+ license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md
+ base_model:
+ - ashen0209/Flux-Dev2Pro
+ - black-forest-labs/FLUX.1-schnell
+ tags:
+ - text-to-image
+ - flux
+ - merge
+ widget:
+ - text: Example 1
+   output:
+     url: images/0015.webp
+ - text: Example 2
+   output:
+     url: images/0022.webp
+ - text: Example 3
+   output:
+     url: images/0012.webp
+ - text: Example 4
+   output:
+     url: images/0014.webp
+ - text: Example 5
+   output:
+     url: images/0033.webp
+ - text: Example 6
+   output:
+     url: images/0032.webp
+ ---
+ 
+ # Flux D+S F8 Diffusers
+ 
+ An A + B merge of ashen0209/Flux-Dev2Pro and black-forest-labs/FLUX.1-schnell, intended for Gradio inference; you can also train LoRA adapters on top of it with the [diffusers](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py) DreamBooth script.
+ 
+ <Gallery />
+ 
+ ## Inference
+ 
+ [Instruction](https://huggingface.co/twodgirl/flux-dev-fp8-e4m3fn-diffusers)
+ 
+ This model needs 10 steps; otherwise the images remain blurry.
+ 
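+ A minimal inference sketch with diffusers, assuming the repo has been downloaded into the directory used in the ComfyUI section below and that your diffusers/torch build can load fp8 checkpoints (here the weights are simply upcast to bfloat16 at load time; the linked instruction covers keeping them in fp8):
+ 
+ ```python
+ import torch
+ from diffusers import FluxPipeline, FluxTransformer2DModel
+ 
+ # Hypothetical local path; point this at wherever you downloaded the repo.
+ transformer = FluxTransformer2DModel.from_pretrained(
+     "./flux-devpro-schnell-merge-fp8-e4m3fn-diffusers",
+     subfolder="transformer",
+     torch_dtype=torch.bfloat16,  # upcast the fp8 weights at load time
+ )
+ pipe = FluxPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-dev",  # supplies the text encoders, VAE and scheduler
+     transformer=transformer,
+     torch_dtype=torch.bfloat16,
+ )
+ pipe.enable_model_cpu_offload()
+ 
+ image = pipe(
+     "a photo of a forest clearing at dawn",
+     num_inference_steps=10,  # fewer steps tend to stay blurry with this merge
+     guidance_scale=3.5,
+ ).images[0]
+ image.save("example.png")
+ ```
+ 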
+ ## ComfyUI
+ 
+ I convert all my checkpoints locally; there are 5+ file formats out there, and they take up too much space.
+ 
+ Download the contents of this repo, then enter the directory. The script below converts the diffusers format into a checkpoint that is compatible with ComfyUI.
+ 
+ ```
+ cd flux-devpro-schnell-merge-fp8-e4m3fn-diffusers
+ pip install safetensors torch
+ python flux_devpro_ckpt.py transformer/diffusion_pytorch_model.safetensors flux-devpro-schnell.safetensors
+ mv flux-devpro-schnell.safetensors path/to/your/ComfyUI/models/unet/flux-devpro-schnell.safetensors
+ ```
+ 
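+ Before moving the file, you can sanity-check the converted checkpoint. A small sketch using safetensors (recent versions expose the fp8 dtypes); the key name and expected dtype are the same ones flux_devpro_ckpt.py asserts:
+ 
+ ```python
+ from safetensors import safe_open
+ 
+ # Open the converted checkpoint lazily and inspect one tensor.
+ with safe_open("flux-devpro-schnell.safetensors", framework="pt") as f:
+     print(len(list(f.keys())), "tensors")
+     t = f.get_tensor("time_in.in_layer.weight")
+     print(t.dtype)  # expected: torch.float8_e4m3fn
+ ```
+ 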
+ In the workflow, the LoadDiffusionModel node should take two values: flux-devpro-schnell.safetensors and fp8_e4m3fn.
+ 
+ Use of this code requires citation and attribution to the author via a link to their Hugging Face profile in all resulting work.
+ 
+ ## Disclaimer
+ 
+ Sharing or reusing the model weights requires a link back to the authors and their Hugging Face profiles. The source models were uploaded by Ashen0209 and BFL.
flux_devpro_ckpt.py ADDED
@@ -0,0 +1,16 @@
+ import gc
+ from map_from_diffusers import convert_diffusers_to_flux_checkpoint
+ from safetensors.torch import load_file, save_file
+ import sys
+ import torch
+ 
+ ###
+ # Code from huggingface/twodgirl
+ # License: apache-2.0
+ #
+ # Usage: python flux_devpro_ckpt.py <diffusers transformer safetensors> <output safetensors>
+ 
+ if __name__ == '__main__':
+     # Convert the diffusers-format state dict back to the original Flux layout that ComfyUI expects.
+     sd = convert_diffusers_to_flux_checkpoint(load_file(sys.argv[1]))
+     # The merge is stored in fp8; make sure the dtype survived the conversion.
+     assert sd['time_in.in_layer.weight'].dtype == torch.float8_e4m3fn
+     print(len(sd))
+     gc.collect()
+     save_file(sd, sys.argv[2])
gitattributes (16) ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
map_from_diffusers.py ADDED
@@ -0,0 +1,348 @@
+ import torch
+ 
+ ##
+ # Code from huggingface/twodgirl
+ # License: apache-2.0
+ #
+ # Reverse of the script from
+ # https://github.com/huggingface/diffusers/blob/main/scripts/convert_flux_to_diffusers.py
+ 
+ def swap_scale_shift(weight):
+     shift, scale = weight.chunk(2, dim=0)
+     new_weight = torch.cat([scale, shift], dim=0)
+ 
+     return new_weight
+ 
+ def convert_diffusers_to_flux_checkpoint(
+     converted_state_dict,
+     num_layers=19,
+     num_single_layers=38,
+     inner_dim=3072,
+     mlp_ratio=4.0
+ ):
+     """
+     Reverses the conversion from Diffusers checkpoint to Flux Transformer format.
+ 
+     This function takes a state dictionary that has been converted to the Diffusers format
+     and transforms it back to the original Flux Transformer checkpoint format. It systematically
+     maps each parameter from the Diffusers naming and structure back to the original format,
+     handling different components such as embeddings, transformer blocks, and normalization layers.
+ 
+     Args:
+         converted_state_dict (dict): The state dictionary in Diffusers format to be converted back.
+         num_layers (int, optional): Number of transformer layers in the original model. Default is 19.
+         num_single_layers (int, optional): Number of single transformer layers. Default is 38.
+         inner_dim (int, optional): The inner dimension size for MLP layers. Default is 3072.
+         mlp_ratio (float, optional): The ratio to compute the MLP hidden dimension. Default is 4.0.
+ 
+     Returns:
+         dict: The original state dictionary in Flux Transformer checkpoint format.
+     """
+     # Initialize an empty dictionary to store the original state dictionary.
+     original_state_dict = {}
+ 
+     # -------------------------
+     # Handle Time Text Embeddings
+     # -------------------------
+ 
+     # Map the timestep embedder weights and biases back to "time_in.in_layer"
+     original_state_dict["time_in.in_layer.weight"] = converted_state_dict.pop(
+         "time_text_embed.timestep_embedder.linear_1.weight"
+     )
+     original_state_dict["time_in.in_layer.bias"] = converted_state_dict.pop(
+         "time_text_embed.timestep_embedder.linear_1.bias"
+     )
+     original_state_dict["time_in.out_layer.weight"] = converted_state_dict.pop(
+         "time_text_embed.timestep_embedder.linear_2.weight"
+     )
+     original_state_dict["time_in.out_layer.bias"] = converted_state_dict.pop(
+         "time_text_embed.timestep_embedder.linear_2.bias"
+     )
+ 
+     # Map the text embedder weights and biases back to "vector_in.in_layer"
+     original_state_dict["vector_in.in_layer.weight"] = converted_state_dict.pop(
+         "time_text_embed.text_embedder.linear_1.weight"
+     )
+     original_state_dict["vector_in.in_layer.bias"] = converted_state_dict.pop(
+         "time_text_embed.text_embedder.linear_1.bias"
+     )
+     original_state_dict["vector_in.out_layer.weight"] = converted_state_dict.pop(
+         "time_text_embed.text_embedder.linear_2.weight"
+     )
+     original_state_dict["vector_in.out_layer.bias"] = converted_state_dict.pop(
+         "time_text_embed.text_embedder.linear_2.bias"
+     )
+ 
+     # -------------------------
+     # Handle Guidance Embeddings (if present)
+     # -------------------------
+ 
+     # Check if any keys related to guidance are present in the converted_state_dict
+     has_guidance = any("guidance_embedder" in k for k in converted_state_dict)
+     if has_guidance:
+         # Map the guidance embedder weights and biases back to "guidance_in.in_layer"
+         original_state_dict["guidance_in.in_layer.weight"] = converted_state_dict.pop(
+             "time_text_embed.guidance_embedder.linear_1.weight"
+         )
+         original_state_dict["guidance_in.in_layer.bias"] = converted_state_dict.pop(
+             "time_text_embed.guidance_embedder.linear_1.bias"
+         )
+         original_state_dict["guidance_in.out_layer.weight"] = converted_state_dict.pop(
+             "time_text_embed.guidance_embedder.linear_2.weight"
+         )
+         original_state_dict["guidance_in.out_layer.bias"] = converted_state_dict.pop(
+             "time_text_embed.guidance_embedder.linear_2.bias"
+         )
+ 
+     # -------------------------
+     # Handle Context and Image Embeddings
+     # -------------------------
+ 
+     # Map the context embedder weights and biases back to "txt_in"
+     original_state_dict["txt_in.weight"] = converted_state_dict.pop("context_embedder.weight")
+     original_state_dict["txt_in.bias"] = converted_state_dict.pop("context_embedder.bias")
+ 
+     # Map the image embedder weights and biases back to "img_in"
+     original_state_dict["img_in.weight"] = converted_state_dict.pop("x_embedder.weight")
+     original_state_dict["img_in.bias"] = converted_state_dict.pop("x_embedder.bias")
+ 
+     # -------------------------
+     # Handle Transformer Blocks
+     # -------------------------
+ 
+     for i in range(num_layers):
+         # Define the prefix for the current transformer block in the converted_state_dict
+         block_prefix = f"transformer_blocks.{i}."
+ 
+         # -------------------------
+         # Map Norm1 Layers
+         # -------------------------
+ 
+         # Map the norm1 linear layer weights and biases back to "double_blocks.{i}.img_mod.lin"
+         original_state_dict[f"double_blocks.{i}.img_mod.lin.weight"] = converted_state_dict.pop(
+             f"{block_prefix}norm1.linear.weight"
+         )
+         original_state_dict[f"double_blocks.{i}.img_mod.lin.bias"] = converted_state_dict.pop(
+             f"{block_prefix}norm1.linear.bias"
+         )
+ 
+         # Map the norm1_context linear layer weights and biases back to "double_blocks.{i}.txt_mod.lin"
+         original_state_dict[f"double_blocks.{i}.txt_mod.lin.weight"] = converted_state_dict.pop(
+             f"{block_prefix}norm1_context.linear.weight"
+         )
+         original_state_dict[f"double_blocks.{i}.txt_mod.lin.bias"] = converted_state_dict.pop(
+             f"{block_prefix}norm1_context.linear.bias"
+         )
+ 
+         # -------------------------
+         # Handle Q, K, V Projections for Image Attention
+         # -------------------------
+ 
+         # Retrieve and combine the Q, K, V weights for image attention
+         q_weight = converted_state_dict.pop(f"{block_prefix}attn.to_q.weight")
+         k_weight = converted_state_dict.pop(f"{block_prefix}attn.to_k.weight")
+         v_weight = converted_state_dict.pop(f"{block_prefix}attn.to_v.weight")
+         # Concatenate along the first dimension to form the combined QKV weight
+         original_state_dict[f"double_blocks.{i}.img_attn.qkv.weight"] = torch.cat([q_weight, k_weight, v_weight], dim=0)
+ 
+         # Retrieve and combine the Q, K, V biases for image attention
+         q_bias = converted_state_dict.pop(f"{block_prefix}attn.to_q.bias")
+         k_bias = converted_state_dict.pop(f"{block_prefix}attn.to_k.bias")
+         v_bias = converted_state_dict.pop(f"{block_prefix}attn.to_v.bias")
+         # Concatenate along the first dimension to form the combined QKV bias
+         original_state_dict[f"double_blocks.{i}.img_attn.qkv.bias"] = torch.cat([q_bias, k_bias, v_bias], dim=0)
+ 
+         # -------------------------
+         # Handle Q, K, V Projections for Text Attention
+         # -------------------------
+ 
+         # Retrieve and combine the additional Q, K, V projections for context (text) attention
+         add_q_weight = converted_state_dict.pop(f"{block_prefix}attn.add_q_proj.weight")
+         add_k_weight = converted_state_dict.pop(f"{block_prefix}attn.add_k_proj.weight")
+         add_v_weight = converted_state_dict.pop(f"{block_prefix}attn.add_v_proj.weight")
+         # Concatenate along the first dimension to form the combined QKV weight for text
+         original_state_dict[f"double_blocks.{i}.txt_attn.qkv.weight"] = torch.cat([add_q_weight, add_k_weight, add_v_weight], dim=0)
+ 
+         add_q_bias = converted_state_dict.pop(f"{block_prefix}attn.add_q_proj.bias")
+         add_k_bias = converted_state_dict.pop(f"{block_prefix}attn.add_k_proj.bias")
+         add_v_bias = converted_state_dict.pop(f"{block_prefix}attn.add_v_proj.bias")
+         # Concatenate along the first dimension to form the combined QKV bias for text
+         original_state_dict[f"double_blocks.{i}.txt_attn.qkv.bias"] = torch.cat([add_q_bias, add_k_bias, add_v_bias], dim=0)
+ 
+         # -------------------------
+         # Map Attention Norm Layers
+         # -------------------------
+ 
+         # Map the attention query norm weights back to "double_blocks.{i}.img_attn.norm.query_norm.scale"
+         original_state_dict[f"double_blocks.{i}.img_attn.norm.query_norm.scale"] = converted_state_dict.pop(
+             f"{block_prefix}attn.norm_q.weight"
+         )
+ 
+         # Map the attention key norm weights back to "double_blocks.{i}.img_attn.norm.key_norm.scale"
+         original_state_dict[f"double_blocks.{i}.img_attn.norm.key_norm.scale"] = converted_state_dict.pop(
+             f"{block_prefix}attn.norm_k.weight"
+         )
+ 
+         # Map the added attention query norm weights back to "double_blocks.{i}.txt_attn.norm.query_norm.scale"
+         original_state_dict[f"double_blocks.{i}.txt_attn.norm.query_norm.scale"] = converted_state_dict.pop(
+             f"{block_prefix}attn.norm_added_q.weight"
+         )
+ 
+         # Map the added attention key norm weights back to "double_blocks.{i}.txt_attn.norm.key_norm.scale"
+         original_state_dict[f"double_blocks.{i}.txt_attn.norm.key_norm.scale"] = converted_state_dict.pop(
+             f"{block_prefix}attn.norm_added_k.weight"
+         )
+ 
+         # -------------------------
+         # Handle Feed-Forward Networks (FFNs) for Image and Text
+         # -------------------------
+ 
+         # Map the image MLP projection layers back to "double_blocks.{i}.img_mlp"
+         original_state_dict[f"double_blocks.{i}.img_mlp.0.weight"] = converted_state_dict.pop(
+             f"{block_prefix}ff.net.0.proj.weight"
+         )
+         original_state_dict[f"double_blocks.{i}.img_mlp.0.bias"] = converted_state_dict.pop(
+             f"{block_prefix}ff.net.0.proj.bias"
+         )
+         original_state_dict[f"double_blocks.{i}.img_mlp.2.weight"] = converted_state_dict.pop(
+             f"{block_prefix}ff.net.2.weight"
+         )
+         original_state_dict[f"double_blocks.{i}.img_mlp.2.bias"] = converted_state_dict.pop(
+             f"{block_prefix}ff.net.2.bias"
+         )
+ 
+         # Map the text MLP projection layers back to "double_blocks.{i}.txt_mlp"
+         original_state_dict[f"double_blocks.{i}.txt_mlp.0.weight"] = converted_state_dict.pop(
+             f"{block_prefix}ff_context.net.0.proj.weight"
+         )
+         original_state_dict[f"double_blocks.{i}.txt_mlp.0.bias"] = converted_state_dict.pop(
+             f"{block_prefix}ff_context.net.0.proj.bias"
+         )
+         original_state_dict[f"double_blocks.{i}.txt_mlp.2.weight"] = converted_state_dict.pop(
+             f"{block_prefix}ff_context.net.2.weight"
+         )
+         original_state_dict[f"double_blocks.{i}.txt_mlp.2.bias"] = converted_state_dict.pop(
+             f"{block_prefix}ff_context.net.2.bias"
+         )
+ 
+         # -------------------------
+         # Handle Attention Output Projections
+         # -------------------------
+ 
+         # Map the image attention output projection weights and biases back to "double_blocks.{i}.img_attn.proj"
+         original_state_dict[f"double_blocks.{i}.img_attn.proj.weight"] = converted_state_dict.pop(
+             f"{block_prefix}attn.to_out.0.weight"
+         )
+         original_state_dict[f"double_blocks.{i}.img_attn.proj.bias"] = converted_state_dict.pop(
+             f"{block_prefix}attn.to_out.0.bias"
+         )
+ 
+         # Map the text attention output projection weights and biases back to "double_blocks.{i}.txt_attn.proj"
+         original_state_dict[f"double_blocks.{i}.txt_attn.proj.weight"] = converted_state_dict.pop(
+             f"{block_prefix}attn.to_add_out.weight"
+         )
+         original_state_dict[f"double_blocks.{i}.txt_attn.proj.bias"] = converted_state_dict.pop(
+             f"{block_prefix}attn.to_add_out.bias"
+         )
+ 
+     # -------------------------
+     # Handle Single Transformer Blocks
+     # -------------------------
+ 
+     for i in range(num_single_layers):
+         # Define the prefix for the current single transformer block in the converted_state_dict
+         block_prefix = f"single_transformer_blocks.{i}."
+ 
+         # -------------------------
+         # Map Norm Layers
+         # -------------------------
+ 
+         # Map the normalization linear layer weights and biases back to "single_blocks.{i}.modulation.lin"
+         original_state_dict[f"single_blocks.{i}.modulation.lin.weight"] = converted_state_dict.pop(
+             f"{block_prefix}norm.linear.weight"
+         )
+         original_state_dict[f"single_blocks.{i}.modulation.lin.bias"] = converted_state_dict.pop(
+             f"{block_prefix}norm.linear.bias"
+         )
+ 
+         # -------------------------
+         # Handle Q, K, V Projections and MLP
+         # -------------------------
+ 
+         # Retrieve the Q, K, V weights and the MLP projection weight
+         q_weight = converted_state_dict.pop(f"{block_prefix}attn.to_q.weight")
+         k_weight = converted_state_dict.pop(f"{block_prefix}attn.to_k.weight")
+         v_weight = converted_state_dict.pop(f"{block_prefix}attn.to_v.weight")
+         proj_mlp_weight = converted_state_dict.pop(f"{block_prefix}proj_mlp.weight")
+ 
+         # Concatenate Q, K, V, and MLP weights to form the combined linear1.weight
+         combined_weight = torch.cat([q_weight, k_weight, v_weight, proj_mlp_weight], dim=0)
+         original_state_dict[f"single_blocks.{i}.linear1.weight"] = combined_weight
+ 
+         # Retrieve the Q, K, V biases and the MLP projection bias
+         q_bias = converted_state_dict.pop(f"{block_prefix}attn.to_q.bias")
+         k_bias = converted_state_dict.pop(f"{block_prefix}attn.to_k.bias")
+         v_bias = converted_state_dict.pop(f"{block_prefix}attn.to_v.bias")
+         proj_mlp_bias = converted_state_dict.pop(f"{block_prefix}proj_mlp.bias")
+ 
+         # Concatenate Q, K, V, and MLP biases to form the combined linear1.bias
+         combined_bias = torch.cat([q_bias, k_bias, v_bias, proj_mlp_bias], dim=0)
+         original_state_dict[f"single_blocks.{i}.linear1.bias"] = combined_bias
+ 
+         # -------------------------
+         # Map Attention Normalization Weights
+         # -------------------------
+ 
+         # Map the attention query norm weights back to "single_blocks.{i}.norm.query_norm.scale"
+         original_state_dict[f"single_blocks.{i}.norm.query_norm.scale"] = converted_state_dict.pop(
+             f"{block_prefix}attn.norm_q.weight"
+         )
+ 
+         # Map the attention key norm weights back to "single_blocks.{i}.norm.key_norm.scale"
+         original_state_dict[f"single_blocks.{i}.norm.key_norm.scale"] = converted_state_dict.pop(
+             f"{block_prefix}attn.norm_k.weight"
+         )
+ 
+         # -------------------------
+         # Handle Projection Output
+         # -------------------------
+ 
+         # Map the projection output weights and biases back to "single_blocks.{i}.linear2"
+         original_state_dict[f"single_blocks.{i}.linear2.weight"] = converted_state_dict.pop(
+             f"{block_prefix}proj_out.weight"
+         )
+         original_state_dict[f"single_blocks.{i}.linear2.bias"] = converted_state_dict.pop(
+             f"{block_prefix}proj_out.bias"
+         )
+ 
+     # -------------------------
+     # Handle Final Output Projection and Normalization
+     # -------------------------
+ 
+     # Map the final output projection weights and biases back to "final_layer.linear"
+     original_state_dict["final_layer.linear.weight"] = converted_state_dict.pop("proj_out.weight")
+     original_state_dict["final_layer.linear.bias"] = converted_state_dict.pop("proj_out.bias")
+ 
+     # Reverse the swap_scale_shift transformation for normalization weights and biases
+     original_state_dict["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(
+         converted_state_dict.pop("norm_out.linear.weight")
+     )
+     original_state_dict["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(
+         converted_state_dict.pop("norm_out.linear.bias")
+     )
+ 
+     # -------------------------
+     # Handle Remaining Parameters (if any)
+     # -------------------------
+ 
+     # It's possible that there are remaining parameters that were not mapped.
+     # Depending on your use case, you can handle them here or raise an error.
+     if len(converted_state_dict) > 0:
+         # For debugging purposes, you might want to log or print the remaining keys
+         remaining_keys = list(converted_state_dict.keys())
+         print(f"Warning: The following keys were not mapped and remain in the state dict: {remaining_keys}")
+         # Optionally, you can choose to include them or exclude them from the original_state_dict
+ 
+     return original_state_dict
map_streamer.py ADDED
@@ -0,0 +1,101 @@
+ import json
+ import struct
+ import torch
+ import threading
+ import warnings
+ 
+ ###
+ # Code from ljleb/sd-mecha/sd_mecha/streaming.py
+ 
+ DTYPE_MAPPING = {
+     'F64': (torch.float64, 8),
+     'F32': (torch.float32, 4),
+     'F16': (torch.float16, 2),
+     'BF16': (torch.bfloat16, 2),
+     'I8': (torch.int8, 1),
+     'I64': (torch.int64, 8),
+     'I32': (torch.int32, 4),
+     'I16': (torch.int16, 2),
+     "F8_E4M3": (torch.float8_e4m3fn, 1),
+     "F8_E5M2": (torch.float8_e5m2, 1),
+ }
+ 
+ # Dict-like, read-only view over a .safetensors file; tensors are read from disk
+ # on access through a reusable buffer instead of loading the whole file at once.
+ class InSafetensorsDict:
+     def __init__(self, f, buffer_size):
+         self.default_buffer_size = buffer_size
+         self.file = f
+         self.header_size, self.header = self._read_header()
+         self.buffer = bytearray()
+         self.buffer_start_offset = 8 + self.header_size
+         self.lock = threading.Lock()
+ 
+     def __del__(self):
+         self.close()
+ 
+     def __getitem__(self, key):
+         if key not in self.header or key == "__metadata__":
+             raise KeyError(key)
+         return self._load_tensor(key)
+ 
+     def __iter__(self):
+         return iter(self.keys())
+ 
+     def __len__(self):
+         return len(self.header)
+ 
+     def close(self):
+         self.file.close()
+         self.buffer = None
+         self.header = None
+ 
+     def keys(self):
+         return (
+             key
+             for key in self.header.keys()
+             if key != "__metadata__"
+         )
+ 
+     def values(self):
+         for key in self.keys():
+             yield self[key]
+ 
+     def items(self):
+         for key in self.keys():
+             yield key, self[key]
+ 
+     def _read_header(self):
+         header_size_bytes = self.file.read(8)
+         header_size = struct.unpack('<Q', header_size_bytes)[0]
+         header_json = self.file.read(header_size).decode('utf-8').strip()
+         header = json.loads(header_json)
+ 
+         # sort by memory order to reduce seek time
+         sorted_header = dict(sorted(header.items(), key=lambda item: item[1].get('data_offsets', [0])[0]))
+         return header_size, sorted_header
+ 
+     def _ensure_buffer(self, start_pos, length):
+         # Refill the buffer only when the requested byte range is not already buffered.
+         if start_pos < self.buffer_start_offset or start_pos + length > self.buffer_start_offset + len(self.buffer):
+             self.file.seek(start_pos)
+             necessary_buffer_size = max(self.default_buffer_size, length)
+             if len(self.buffer) < necessary_buffer_size:
+                 self.buffer = bytearray(necessary_buffer_size)
+             else:
+                 self.buffer = self.buffer[:necessary_buffer_size]
+ 
+             self.file.readinto(self.buffer)
+             self.buffer_start_offset = start_pos
+ 
+     def _load_tensor(self, tensor_name):
+         # Locate the tensor's byte range from the header, then view it as a torch tensor.
+         tensor_info = self.header[tensor_name]
+         offsets = tensor_info['data_offsets']
+         dtype, dtype_bytes = DTYPE_MAPPING[tensor_info['dtype']]
+         shape = tensor_info['shape']
+         total_bytes = offsets[1] - offsets[0]
+         absolute_start_pos = 8 + self.header_size + offsets[0]
+         with warnings.catch_warnings():
+             warnings.simplefilter('ignore')
+             with self.lock:
+                 self._ensure_buffer(absolute_start_pos, total_bytes)
+                 buffer_offset = absolute_start_pos - self.buffer_start_offset
+                 return torch.frombuffer(self.buffer, count=total_bytes // dtype_bytes, offset=buffer_offset, dtype=dtype).reshape(shape)
+ 