Upload 5 files

- README (20).md +64 -0
- flux_devpro_ckpt.py +16 -0
- gitattributes (16) +35 -0
- map_from_diffusers.py +348 -0
- map_streamer.py +101 -0
README (20).md
ADDED
@@ -0,0 +1,64 @@
---
license: other
license_name: flux-1-dev-non-commercial-license
license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md
base_model:
- ashen0209/Flux-Dev2Pro
- black-forest-labs/FLUX.1-schnell
tags:
- text-to-image
- flux
- merge
widget:
- text: Example 1
  output:
    url: images/0015.webp
- text: Example 2
  output:
    url: images/0022.webp
- text: Example 3
  output:
    url: images/0012.webp
- text: Example 4
  output:
    url: images/0014.webp
- text: Example 5
  output:
    url: images/0033.webp
- text: Example 6
  output:
    url: images/0032.webp
---

# Flux D+S F8 Diffusers

An A + B merge intended for Gradio inference; you can also build adapters on top of it with [diffusers](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py).

<Gallery />

## Inference

[Instructions](https://huggingface.co/twodgirl/flux-dev-fp8-e4m3fn-diffusers)

This model needs 10 steps; with fewer, the images remain blurry.
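
For a quick start, a minimal diffusers sketch is shown below. It is illustrative only: the repo id, dtype, and guidance value are assumptions, and the linked instructions describe the intended fp8 loading path.

```
import torch
from diffusers import FluxPipeline

# Assumed repo id; point this at wherever you placed the merge.
pipe = FluxPipeline.from_pretrained(
    'twodgirl/flux-devpro-schnell-merge-fp8-e4m3fn-diffusers',
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

image = pipe(
    'Example 1',
    num_inference_steps=10,  # fewer steps tend to stay blurry with this merge
    guidance_scale=3.5,      # assumed value, tune to taste
).images[0]
image.save('out.png')
```
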
## ComfyUI

I convert all my checkpoints locally; there are more than five file formats in circulation, and they take up too much space.

Download the contents of this repo, then enter the directory. The script below converts the diffusers format into a checkpoint that is compatible with ComfyUI.

```
cd flux-devpro-schnell-merge-fp8-e4m3fn-diffusers
pip install safetensors torch
python flux_devpro_ckpt.py transformer/diffusion_pytorch_model.safetensors flux-devpro-schnell.safetensors
mv flux-devpro-schnell.safetensors path/to/your/ComfyUI/models/unet/flux-devpro-schnell.safetensors
```

In the workflow, the LoadDiffusionModel node should take two values: flux-devpro-schnell.safetensors and fp8_e4m3fn.

Use of this code requires citation and attribution to the author via a link to their Hugging Face profile in all resulting work.

## Disclaimer

Sharing or reusing the model weights requires a link back to the authors and their Hugging Face profiles. The source models were uploaded by Ashen0209 and BFL.
flux_devpro_ckpt.py
ADDED
@@ -0,0 +1,16 @@
import gc
from map_from_diffusers import convert_diffusers_to_flux_checkpoint
from safetensors.torch import load_file, save_file
import sys
import torch

###
# Code from huggingface/twodgirl
# License: apache-2.0

if __name__ == '__main__':
    # Usage: python flux_devpro_ckpt.py <diffusers_transformer.safetensors> <output.safetensors>
    # Map the diffusers-format transformer weights back to the Flux checkpoint layout.
    sd = convert_diffusers_to_flux_checkpoint(load_file(sys.argv[1]))
    # The merge is stored in fp8; bail out early if the dtype is unexpected.
    assert sd['time_in.in_layer.weight'].dtype == torch.float8_e4m3fn
    print(len(sd))
    gc.collect()
    save_file(sd, sys.argv[2])
gitattributes (16)
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
map_from_diffusers.py
ADDED
@@ -0,0 +1,348 @@
import torch

##
# Code from huggingface/twodgirl
# License: apache-2.0
#
# Reverse of the script from
# https://github.com/huggingface/diffusers/blob/main/scripts/convert_flux_to_diffusers.py

def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)

    return new_weight

def convert_diffusers_to_flux_checkpoint(
    converted_state_dict,
    num_layers=19,
    num_single_layers=38,
    inner_dim=3072,
    mlp_ratio=4.0
):
    """
    Reverses the conversion from Diffusers checkpoint to Flux Transformer format.

    This function takes a state dictionary that has been converted to the Diffusers format
    and transforms it back to the original Flux Transformer checkpoint format. It systematically
    maps each parameter from the Diffusers naming and structure back to the original format,
    handling different components such as embeddings, transformer blocks, and normalization layers.

    Args:
        converted_state_dict (dict): The state dictionary in Diffusers format to be converted back.
        num_layers (int, optional): Number of transformer layers in the original model. Default is 19.
        num_single_layers (int, optional): Number of single transformer layers. Default is 38.
        inner_dim (int, optional): The inner dimension size for MLP layers. Default is 3072.
        mlp_ratio (float, optional): The ratio to compute the MLP hidden dimension. Default is 4.0.

    Returns:
        dict: The original state dictionary in Flux Transformer checkpoint format.
    """
    # Initialize an empty dictionary to store the original state dictionary.
    original_state_dict = {}

    # -------------------------
    # Handle Time Text Embeddings
    # -------------------------

    # Map the timestep embedder weights and biases back to "time_in.in_layer"
    original_state_dict["time_in.in_layer.weight"] = converted_state_dict.pop(
        "time_text_embed.timestep_embedder.linear_1.weight"
    )
    original_state_dict["time_in.in_layer.bias"] = converted_state_dict.pop(
        "time_text_embed.timestep_embedder.linear_1.bias"
    )
    original_state_dict["time_in.out_layer.weight"] = converted_state_dict.pop(
        "time_text_embed.timestep_embedder.linear_2.weight"
    )
    original_state_dict["time_in.out_layer.bias"] = converted_state_dict.pop(
        "time_text_embed.timestep_embedder.linear_2.bias"
    )

    # Map the text embedder weights and biases back to "vector_in.in_layer"
    original_state_dict["vector_in.in_layer.weight"] = converted_state_dict.pop(
        "time_text_embed.text_embedder.linear_1.weight"
    )
    original_state_dict["vector_in.in_layer.bias"] = converted_state_dict.pop(
        "time_text_embed.text_embedder.linear_1.bias"
    )
    original_state_dict["vector_in.out_layer.weight"] = converted_state_dict.pop(
        "time_text_embed.text_embedder.linear_2.weight"
    )
    original_state_dict["vector_in.out_layer.bias"] = converted_state_dict.pop(
        "time_text_embed.text_embedder.linear_2.bias"
    )

    # -------------------------
    # Handle Guidance Embeddings (if present)
    # -------------------------

    # Check if any keys related to guidance are present in the converted_state_dict
    has_guidance = any("guidance_embedder" in k for k in converted_state_dict)
    if has_guidance:
        # Map the guidance embedder weights and biases back to "guidance_in.in_layer"
        original_state_dict["guidance_in.in_layer.weight"] = converted_state_dict.pop(
            "time_text_embed.guidance_embedder.linear_1.weight"
        )
        original_state_dict["guidance_in.in_layer.bias"] = converted_state_dict.pop(
            "time_text_embed.guidance_embedder.linear_1.bias"
        )
        original_state_dict["guidance_in.out_layer.weight"] = converted_state_dict.pop(
            "time_text_embed.guidance_embedder.linear_2.weight"
        )
        original_state_dict["guidance_in.out_layer.bias"] = converted_state_dict.pop(
            "time_text_embed.guidance_embedder.linear_2.bias"
        )

    # -------------------------
    # Handle Context and Image Embeddings
    # -------------------------

    # Map the context embedder weights and biases back to "txt_in"
    original_state_dict["txt_in.weight"] = converted_state_dict.pop("context_embedder.weight")
    original_state_dict["txt_in.bias"] = converted_state_dict.pop("context_embedder.bias")

    # Map the image embedder weights and biases back to "img_in"
    original_state_dict["img_in.weight"] = converted_state_dict.pop("x_embedder.weight")
    original_state_dict["img_in.bias"] = converted_state_dict.pop("x_embedder.bias")

    # -------------------------
    # Handle Transformer Blocks
    # -------------------------

    for i in range(num_layers):
        # Define the prefix for the current transformer block in the converted_state_dict
        block_prefix = f"transformer_blocks.{i}."

        # -------------------------
        # Map Norm1 Layers
        # -------------------------

        # Map the norm1 linear layer weights and biases back to "double_blocks.{i}.img_mod.lin"
        original_state_dict[f"double_blocks.{i}.img_mod.lin.weight"] = converted_state_dict.pop(
            f"{block_prefix}norm1.linear.weight"
        )
        original_state_dict[f"double_blocks.{i}.img_mod.lin.bias"] = converted_state_dict.pop(
            f"{block_prefix}norm1.linear.bias"
        )

        # Map the norm1_context linear layer weights and biases back to "double_blocks.{i}.txt_mod.lin"
        original_state_dict[f"double_blocks.{i}.txt_mod.lin.weight"] = converted_state_dict.pop(
            f"{block_prefix}norm1_context.linear.weight"
        )
        original_state_dict[f"double_blocks.{i}.txt_mod.lin.bias"] = converted_state_dict.pop(
            f"{block_prefix}norm1_context.linear.bias"
        )

        # -------------------------
        # Handle Q, K, V Projections for Image Attention
        # -------------------------

        # Retrieve and combine the Q, K, V weights for image attention
        q_weight = converted_state_dict.pop(f"{block_prefix}attn.to_q.weight")
        k_weight = converted_state_dict.pop(f"{block_prefix}attn.to_k.weight")
        v_weight = converted_state_dict.pop(f"{block_prefix}attn.to_v.weight")
        # Concatenate along the first dimension to form the combined QKV weight
        original_state_dict[f"double_blocks.{i}.img_attn.qkv.weight"] = torch.cat([q_weight, k_weight, v_weight], dim=0)

        # Retrieve and combine the Q, K, V biases for image attention
        q_bias = converted_state_dict.pop(f"{block_prefix}attn.to_q.bias")
        k_bias = converted_state_dict.pop(f"{block_prefix}attn.to_k.bias")
        v_bias = converted_state_dict.pop(f"{block_prefix}attn.to_v.bias")
        # Concatenate along the first dimension to form the combined QKV bias
        original_state_dict[f"double_blocks.{i}.img_attn.qkv.bias"] = torch.cat([q_bias, k_bias, v_bias], dim=0)

        # -------------------------
        # Handle Q, K, V Projections for Text Attention
        # -------------------------

        # Retrieve and combine the additional Q, K, V projections for context (text) attention
        add_q_weight = converted_state_dict.pop(f"{block_prefix}attn.add_q_proj.weight")
        add_k_weight = converted_state_dict.pop(f"{block_prefix}attn.add_k_proj.weight")
        add_v_weight = converted_state_dict.pop(f"{block_prefix}attn.add_v_proj.weight")
        # Concatenate along the first dimension to form the combined QKV weight for text
        original_state_dict[f"double_blocks.{i}.txt_attn.qkv.weight"] = torch.cat([add_q_weight, add_k_weight, add_v_weight], dim=0)

        add_q_bias = converted_state_dict.pop(f"{block_prefix}attn.add_q_proj.bias")
        add_k_bias = converted_state_dict.pop(f"{block_prefix}attn.add_k_proj.bias")
        add_v_bias = converted_state_dict.pop(f"{block_prefix}attn.add_v_proj.bias")
        # Concatenate along the first dimension to form the combined QKV bias for text
        original_state_dict[f"double_blocks.{i}.txt_attn.qkv.bias"] = torch.cat([add_q_bias, add_k_bias, add_v_bias], dim=0)

        # -------------------------
        # Map Attention Norm Layers
        # -------------------------

        # Map the attention query norm weights back to "double_blocks.{i}.img_attn.norm.query_norm.scale"
        original_state_dict[f"double_blocks.{i}.img_attn.norm.query_norm.scale"] = converted_state_dict.pop(
            f"{block_prefix}attn.norm_q.weight"
        )

        # Map the attention key norm weights back to "double_blocks.{i}.img_attn.norm.key_norm.scale"
        original_state_dict[f"double_blocks.{i}.img_attn.norm.key_norm.scale"] = converted_state_dict.pop(
            f"{block_prefix}attn.norm_k.weight"
        )

        # Map the added attention query norm weights back to "double_blocks.{i}.txt_attn.norm.query_norm.scale"
        original_state_dict[f"double_blocks.{i}.txt_attn.norm.query_norm.scale"] = converted_state_dict.pop(
            f"{block_prefix}attn.norm_added_q.weight"
        )

        # Map the added attention key norm weights back to "double_blocks.{i}.txt_attn.norm.key_norm.scale"
        original_state_dict[f"double_blocks.{i}.txt_attn.norm.key_norm.scale"] = converted_state_dict.pop(
            f"{block_prefix}attn.norm_added_k.weight"
        )

        # -------------------------
        # Handle Feed-Forward Networks (FFNs) for Image and Text
        # -------------------------

        # Map the image MLP projection layers back to "double_blocks.{i}.img_mlp"
        original_state_dict[f"double_blocks.{i}.img_mlp.0.weight"] = converted_state_dict.pop(
            f"{block_prefix}ff.net.0.proj.weight"
        )
        original_state_dict[f"double_blocks.{i}.img_mlp.0.bias"] = converted_state_dict.pop(
            f"{block_prefix}ff.net.0.proj.bias"
        )
        original_state_dict[f"double_blocks.{i}.img_mlp.2.weight"] = converted_state_dict.pop(
            f"{block_prefix}ff.net.2.weight"
        )
        original_state_dict[f"double_blocks.{i}.img_mlp.2.bias"] = converted_state_dict.pop(
            f"{block_prefix}ff.net.2.bias"
        )

        # Map the text MLP projection layers back to "double_blocks.{i}.txt_mlp"
        original_state_dict[f"double_blocks.{i}.txt_mlp.0.weight"] = converted_state_dict.pop(
            f"{block_prefix}ff_context.net.0.proj.weight"
        )
        original_state_dict[f"double_blocks.{i}.txt_mlp.0.bias"] = converted_state_dict.pop(
            f"{block_prefix}ff_context.net.0.proj.bias"
        )
        original_state_dict[f"double_blocks.{i}.txt_mlp.2.weight"] = converted_state_dict.pop(
            f"{block_prefix}ff_context.net.2.weight"
        )
        original_state_dict[f"double_blocks.{i}.txt_mlp.2.bias"] = converted_state_dict.pop(
            f"{block_prefix}ff_context.net.2.bias"
        )

        # -------------------------
        # Handle Attention Output Projections
        # -------------------------

        # Map the image attention output projection weights and biases back to "double_blocks.{i}.img_attn.proj"
        original_state_dict[f"double_blocks.{i}.img_attn.proj.weight"] = converted_state_dict.pop(
            f"{block_prefix}attn.to_out.0.weight"
        )
        original_state_dict[f"double_blocks.{i}.img_attn.proj.bias"] = converted_state_dict.pop(
            f"{block_prefix}attn.to_out.0.bias"
        )

        # Map the text attention output projection weights and biases back to "double_blocks.{i}.txt_attn.proj"
        original_state_dict[f"double_blocks.{i}.txt_attn.proj.weight"] = converted_state_dict.pop(
            f"{block_prefix}attn.to_add_out.weight"
        )
        original_state_dict[f"double_blocks.{i}.txt_attn.proj.bias"] = converted_state_dict.pop(
            f"{block_prefix}attn.to_add_out.bias"
        )

    # -------------------------
    # Handle Single Transformer Blocks
    # -------------------------

    for i in range(num_single_layers):
        # Define the prefix for the current single transformer block in the converted_state_dict
        block_prefix = f"single_transformer_blocks.{i}."

        # -------------------------
        # Map Norm Layers
        # -------------------------

        # Map the normalization linear layer weights and biases back to "single_blocks.{i}.modulation.lin"
        original_state_dict[f"single_blocks.{i}.modulation.lin.weight"] = converted_state_dict.pop(
            f"{block_prefix}norm.linear.weight"
        )
        original_state_dict[f"single_blocks.{i}.modulation.lin.bias"] = converted_state_dict.pop(
            f"{block_prefix}norm.linear.bias"
        )

        # -------------------------
        # Handle Q, K, V Projections and MLP
        # -------------------------

        # Retrieve the Q, K, V weights and the MLP projection weight
        q_weight = converted_state_dict.pop(f"{block_prefix}attn.to_q.weight")
        k_weight = converted_state_dict.pop(f"{block_prefix}attn.to_k.weight")
        v_weight = converted_state_dict.pop(f"{block_prefix}attn.to_v.weight")
        proj_mlp_weight = converted_state_dict.pop(f"{block_prefix}proj_mlp.weight")

        # Concatenate Q, K, V, and MLP weights to form the combined linear1.weight
        combined_weight = torch.cat([q_weight, k_weight, v_weight, proj_mlp_weight], dim=0)
        original_state_dict[f"single_blocks.{i}.linear1.weight"] = combined_weight

        # Retrieve the Q, K, V biases and the MLP projection bias
        q_bias = converted_state_dict.pop(f"{block_prefix}attn.to_q.bias")
        k_bias = converted_state_dict.pop(f"{block_prefix}attn.to_k.bias")
        v_bias = converted_state_dict.pop(f"{block_prefix}attn.to_v.bias")
        proj_mlp_bias = converted_state_dict.pop(f"{block_prefix}proj_mlp.bias")

        # Concatenate Q, K, V, and MLP biases to form the combined linear1.bias
        combined_bias = torch.cat([q_bias, k_bias, v_bias, proj_mlp_bias], dim=0)
        original_state_dict[f"single_blocks.{i}.linear1.bias"] = combined_bias

        # -------------------------
        # Map Attention Normalization Weights
        # -------------------------

        # Map the attention query norm weights back to "single_blocks.{i}.norm.query_norm.scale"
        original_state_dict[f"single_blocks.{i}.norm.query_norm.scale"] = converted_state_dict.pop(
            f"{block_prefix}attn.norm_q.weight"
        )

        # Map the attention key norm weights back to "single_blocks.{i}.norm.key_norm.scale"
        original_state_dict[f"single_blocks.{i}.norm.key_norm.scale"] = converted_state_dict.pop(
            f"{block_prefix}attn.norm_k.weight"
        )

        # -------------------------
        # Handle Projection Output
        # -------------------------

        # Map the projection output weights and biases back to "single_blocks.{i}.linear2"
        original_state_dict[f"single_blocks.{i}.linear2.weight"] = converted_state_dict.pop(
            f"{block_prefix}proj_out.weight"
        )
        original_state_dict[f"single_blocks.{i}.linear2.bias"] = converted_state_dict.pop(
            f"{block_prefix}proj_out.bias"
        )

    # -------------------------
    # Handle Final Output Projection and Normalization
    # -------------------------

    # Map the final output projection weights and biases back to "final_layer.linear"
    original_state_dict["final_layer.linear.weight"] = converted_state_dict.pop("proj_out.weight")
    original_state_dict["final_layer.linear.bias"] = converted_state_dict.pop("proj_out.bias")

    # Reverse the swap_scale_shift transformation for normalization weights and biases
    original_state_dict["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(
        converted_state_dict.pop("norm_out.linear.weight")
    )
    original_state_dict["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(
        converted_state_dict.pop("norm_out.linear.bias")
    )

    # -------------------------
    # Handle Remaining Parameters (if any)
    # -------------------------

    # It's possible that there are remaining parameters that were not mapped.
    # Depending on your use case, you can handle them here or raise an error.
    if len(converted_state_dict) > 0:
        # For debugging purposes, you might want to log or print the remaining keys
        remaining_keys = list(converted_state_dict.keys())
        print(f"Warning: The following keys were not mapped and remain in the state dict: {remaining_keys}")
        # Optionally, you can choose to include them or exclude them from the original_state_dict

    return original_state_dict
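
For orientation, here is a toy sanity check of the key mapping above. It is not part of the upload: the tensor shapes are made up, and only the embedder and final-layer paths are exercised by setting num_layers=0 and num_single_layers=0.

```
import torch
from map_from_diffusers import convert_diffusers_to_flux_checkpoint

d = 4  # tiny stand-in for the real inner_dim of 3072
toy = {
    'time_text_embed.timestep_embedder.linear_1.weight': torch.zeros(d, 256),
    'time_text_embed.timestep_embedder.linear_1.bias': torch.zeros(d),
    'time_text_embed.timestep_embedder.linear_2.weight': torch.zeros(d, d),
    'time_text_embed.timestep_embedder.linear_2.bias': torch.zeros(d),
    'time_text_embed.text_embedder.linear_1.weight': torch.zeros(d, 768),
    'time_text_embed.text_embedder.linear_1.bias': torch.zeros(d),
    'time_text_embed.text_embedder.linear_2.weight': torch.zeros(d, d),
    'time_text_embed.text_embedder.linear_2.bias': torch.zeros(d),
    'context_embedder.weight': torch.zeros(d, 8),
    'context_embedder.bias': torch.zeros(d),
    'x_embedder.weight': torch.zeros(d, 8),
    'x_embedder.bias': torch.zeros(d),
    'proj_out.weight': torch.zeros(8, d),
    'proj_out.bias': torch.zeros(8),
    'norm_out.linear.weight': torch.zeros(2 * d, d),
    'norm_out.linear.bias': torch.zeros(2 * d),
}
sd = convert_diffusers_to_flux_checkpoint(toy, num_layers=0, num_single_layers=0)
# Expect time_in.*, vector_in.*, txt_in.*, img_in.* and final_layer.* keys.
print(sorted(sd.keys()))
```
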
map_streamer.py
ADDED
@@ -0,0 +1,101 @@
import json
import struct
import torch
import threading
import warnings

###
# Code from ljleb/sd-mecha/sd_mecha/streaming.py

# Maps safetensors dtype strings to (torch dtype, bytes per element).
DTYPE_MAPPING = {
    'F64': (torch.float64, 8),
    'F32': (torch.float32, 4),
    'F16': (torch.float16, 2),
    'BF16': (torch.bfloat16, 2),
    'I8': (torch.int8, 1),
    'I64': (torch.int64, 8),
    'I32': (torch.int32, 4),
    'I16': (torch.int16, 2),
    'F8_E4M3': (torch.float8_e4m3fn, 1),
    'F8_E5M2': (torch.float8_e5m2, 1),
}

class InSafetensorsDict:
    # Dict-like, read-only view over a .safetensors file that loads tensors
    # on demand through a reusable buffer instead of reading the whole file.
    def __init__(self, f, buffer_size):
        self.default_buffer_size = buffer_size
        self.file = f
        self.header_size, self.header = self._read_header()
        self.buffer = bytearray()
        self.buffer_start_offset = 8 + self.header_size
        self.lock = threading.Lock()

    def __del__(self):
        self.close()

    def __getitem__(self, key):
        if key not in self.header or key == "__metadata__":
            raise KeyError(key)
        return self._load_tensor(key)

    def __iter__(self):
        return iter(self.keys())

    def __len__(self):
        return len(self.header)

    def close(self):
        self.file.close()
        self.buffer = None
        self.header = None

    def keys(self):
        return (
            key
            for key in self.header.keys()
            if key != "__metadata__"
        )

    def values(self):
        for key in self.keys():
            yield self[key]

    def items(self):
        for key in self.keys():
            yield key, self[key]

    def _read_header(self):
        # The safetensors format starts with an 8-byte little-endian header size,
        # followed by a JSON header that describes each tensor.
        header_size_bytes = self.file.read(8)
        header_size = struct.unpack('<Q', header_size_bytes)[0]
        header_json = self.file.read(header_size).decode('utf-8').strip()
        header = json.loads(header_json)

        # sort by memory order to reduce seek time
        sorted_header = dict(sorted(header.items(), key=lambda item: item[1].get('data_offsets', [0])[0]))
        return header_size, sorted_header

    def _ensure_buffer(self, start_pos, length):
        # Refill the buffer only when the requested byte range is not already cached.
        if start_pos < self.buffer_start_offset or start_pos + length > self.buffer_start_offset + len(self.buffer):
            self.file.seek(start_pos)
            necessary_buffer_size = max(self.default_buffer_size, length)
            if len(self.buffer) < necessary_buffer_size:
                self.buffer = bytearray(necessary_buffer_size)
            else:
                self.buffer = self.buffer[:necessary_buffer_size]

            self.file.readinto(self.buffer)
            self.buffer_start_offset = start_pos

    def _load_tensor(self, tensor_name):
        tensor_info = self.header[tensor_name]
        offsets = tensor_info['data_offsets']
        dtype, dtype_bytes = DTYPE_MAPPING[tensor_info['dtype']]
        shape = tensor_info['shape']
        total_bytes = offsets[1] - offsets[0]
        # Tensor data starts after the 8-byte size prefix and the JSON header.
        absolute_start_pos = 8 + self.header_size + offsets[0]
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            with self.lock:
                self._ensure_buffer(absolute_start_pos, total_bytes)
                buffer_offset = absolute_start_pos - self.buffer_start_offset
                return torch.frombuffer(self.buffer, count=total_bytes // dtype_bytes, offset=buffer_offset, dtype=dtype).reshape(shape)
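
A minimal usage sketch of the streaming reader above, illustrative only; the file path and buffer size are assumptions, not part of the upload.

```
from map_streamer import InSafetensorsDict

# Assumed path to a safetensors file from this repo.
with open('transformer/diffusion_pytorch_model.safetensors', 'rb') as f:
    sd = InSafetensorsDict(f, buffer_size=2 ** 24)  # 16 MiB read buffer
    for name, tensor in sd.items():
        print(name, tuple(tensor.shape), tensor.dtype)
        break  # inspect the first tensor only
```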