wjs0725 commited on
Commit
4cc901a
1 Parent(s): e1d7eb4

Upload 21 files

Browse files
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import time
4
+ from io import BytesIO
5
+ import uuid
6
+ from dataclasses import dataclass
7
+ from glob import iglob
8
+ import argparse
9
+ from einops import rearrange
10
+ from fire import Fire
11
+ from PIL import ExifTags, Image
12
+
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import gradio as gr
17
+ import numpy as np
18
+ from transformers import pipeline
19
+
20
+ from flux.sampling import denoise, get_schedule, prepare, unpack
21
+ from flux.util import (configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5)
22
+
23
+ @dataclass
24
+ class SamplingOptions:
25
+ source_prompt: str
26
+ target_prompt: str
27
+ # prompt: str
28
+ width: int
29
+ height: int
30
+ num_steps: int
31
+ guidance: float
32
+ seed: int | None
33
+
34
+ @torch.inference_mode()
35
+ def encode(init_image, torch_device, ae):
36
+ init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
37
+ init_image = init_image.unsqueeze(0)
38
+ init_image = init_image.to(torch_device)
39
+ with torch.no_grad():
40
+ init_image = ae.encode(init_image.to()).to(torch.bfloat16)
41
+ return init_image
42
+
43
+
44
+ class FluxEditor:
45
+ def __init__(self, args):
46
+ self.args = args
47
+ self.device = torch.device(args.device)
48
+ self.offload = args.offload
49
+ self.name = args.name
50
+ self.is_schnell = args.name == "flux-schnell"
51
+
52
+ self.feature_path = 'feature'
53
+ self.output_dir = 'result'
54
+ self.add_sampling_metadata = True
55
+
56
+ if self.name not in configs:
57
+ available = ", ".join(configs.keys())
58
+ raise ValueError(f"Got unknown model name: {name}, chose from {available}")
59
+
60
+ # init all components
61
+ self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
62
+ self.clip = load_clip(self.device)
63
+ self.model = load_flow_model(self.name, device="cpu" if self.offload else self.device)
64
+ self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)
65
+ self.t5.eval()
66
+ self.clip.eval()
67
+ self.ae.eval()
68
+ self.model.eval()
69
+
70
+ if self.offload:
71
+ self.model.cpu()
72
+ torch.cuda.empty_cache()
73
+ self.ae.encoder.to(self.device)
74
+
75
+ @torch.inference_mode()
76
+ def edit(self, init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
77
+ torch.cuda.empty_cache()
78
+ seed = None
79
+ # if seed == -1:
80
+ # seed = None
81
+
82
+ shape = init_image.shape
83
+
84
+ new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
85
+ new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
86
+
87
+ init_image = init_image[:new_h, :new_w, :]
88
+
89
+ width, height = init_image.shape[0], init_image.shape[1]
90
+ init_image = encode(init_image, self.device, self.ae)
91
+
92
+ print(init_image.shape)
93
+
94
+ rng = torch.Generator(device="cpu")
95
+ opts = SamplingOptions(
96
+ source_prompt=source_prompt,
97
+ target_prompt=target_prompt,
98
+ width=width,
99
+ height=height,
100
+ num_steps=num_steps,
101
+ guidance=guidance,
102
+ seed=seed,
103
+ )
104
+ if opts.seed is None:
105
+ opts.seed = torch.Generator(device="cpu").seed()
106
+
107
+ print(f"Generating with seed {opts.seed}:\n{opts.source_prompt}")
108
+ t0 = time.perf_counter()
109
+
110
+ opts.seed = None
111
+ if self.offload:
112
+ self.ae = self.ae.cpu()
113
+ torch.cuda.empty_cache()
114
+ self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
115
+
116
+ #############inverse#######################
117
+ info = {}
118
+ info['feature'] = {}
119
+ info['inject_step'] = inject_step
120
+
121
+ if not os.path.exists(self.feature_path):
122
+ os.mkdir(self.feature_path)
123
+
124
+ with torch.no_grad():
125
+ inp = prepare(self.t5, self.clip, init_image, prompt=opts.source_prompt)
126
+ inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
127
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
128
+
129
+ # offload TEs to CPU, load model to gpu
130
+ if self.offload:
131
+ self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
132
+ torch.cuda.empty_cache()
133
+ self.model = self.model.to(self.device)
134
+
135
+ # inversion initial noise
136
+ with torch.no_grad():
137
+ z, info = denoise(self.model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
138
+
139
+ inp_target["img"] = z
140
+
141
+ timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(self.name != "flux-schnell"))
142
+
143
+ # denoise initial noise
144
+ x, _ = denoise(self.model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
145
+
146
+ # offload model, load autoencoder to gpu
147
+ if self.offload:
148
+ self.model.cpu()
149
+ torch.cuda.empty_cache()
150
+ self.ae.decoder.to(x.device)
151
+
152
+ # decode latents to pixel space
153
+ x = unpack(x.float(), opts.width, opts.height)
154
+
155
+ output_name = os.path.join(self.output_dir, "img_{idx}.jpg")
156
+ if not os.path.exists(self.output_dir):
157
+ os.makedirs(self.output_dir)
158
+ idx = 0
159
+ else:
160
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
161
+ if len(fns) > 0:
162
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
163
+ else:
164
+ idx = 0
165
+
166
+ with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
167
+ x = self.ae.decode(x)
168
+
169
+ if torch.cuda.is_available():
170
+ torch.cuda.synchronize()
171
+ t1 = time.perf_counter()
172
+
173
+ fn = output_name.format(idx=idx)
174
+ print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
175
+ # bring into PIL format and save
176
+ x = x.clamp(-1, 1)
177
+ x = embed_watermark(x.float())
178
+ x = rearrange(x[0], "c h w -> h w c")
179
+
180
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
181
+ exif_data = Image.Exif()
182
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
183
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
184
+ exif_data[ExifTags.Base.Model] = self.name
185
+ if self.add_sampling_metadata:
186
+ exif_data[ExifTags.Base.ImageDescription] = source_prompt
187
+ img.save(fn, exif=exif_data, quality=95, subsampling=0)
188
+
189
+
190
+ print("End Edit")
191
+ return img
192
+
193
+
194
+
195
+ def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
196
+ editor = FluxEditor(args)
197
+ is_schnell = model_name == "flux-schnell"
198
+
199
+ with gr.Blocks() as demo:
200
+ gr.Markdown(f"# RF-Edit Demo (FLUX for image editing)")
201
+
202
+ with gr.Row():
203
+ with gr.Column():
204
+ source_prompt = gr.Textbox(label="Source Prompt", value="")
205
+ target_prompt = gr.Textbox(label="Target Prompt", value="")
206
+ init_image = gr.Image(label="Input Image", visible=True)
207
+
208
+
209
+ generate_btn = gr.Button("Generate")
210
+
211
+ with gr.Column():
212
+ with gr.Accordion("Advanced Options", open=True):
213
+ num_steps = gr.Slider(1, 30, 25, step=1, label="Number of steps")
214
+ inject_step = gr.Slider(1, 15, 5, step=1, label="Number of inject steps")
215
+ guidance = gr.Slider(1.0, 10.0, 2, step=0.1, label="Guidance", interactive=not is_schnell)
216
+ # seed = gr.Textbox(0, label="Seed (-1 for random)", visible=False)
217
+ # add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=False)
218
+
219
+ output_image = gr.Image(label="Generated Image")
220
+
221
+ generate_btn.click(
222
+ fn=editor.edit,
223
+ inputs=[init_image, source_prompt, target_prompt, num_steps, inject_step, guidance],
224
+ outputs=[output_image]
225
+ )
226
+
227
+
228
+ return demo
229
+
230
+
231
+ if __name__ == "__main__":
232
+ import argparse
233
+ parser = argparse.ArgumentParser(description="Flux")
234
+ parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name")
235
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
236
+ parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
237
+ parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
238
+
239
+ parser.add_argument("--port", type=int, default=41035)
240
+ args = parser.parse_args()
241
+
242
+ demo = create_demo(args.name, args.device, args.offload)
243
+ demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port)
flux/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from ._version import version as __version__ # type: ignore
3
+ from ._version import version_tuple
4
+ except ImportError:
5
+ __version__ = "unknown (no version information available)"
6
+ version_tuple = (0, 0, "unknown", "noinfo")
7
+
8
+ from pathlib import Path
9
+
10
+ PACKAGE = __package__.replace("_", "-")
11
+ PACKAGE_ROOT = Path(__file__).parent
flux/__main__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .cli import app
2
+
3
+ if __name__ == "__main__":
4
+ app()
flux/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (503 Bytes). View file
 
flux/__pycache__/math.cpython-310.pyc ADDED
Binary file (1.6 kB). View file
 
flux/__pycache__/math.cpython-38.pyc ADDED
Binary file (1.46 kB). View file
 
flux/__pycache__/model.cpython-310.pyc ADDED
Binary file (3.46 kB). View file
 
flux/__pycache__/sampling.cpython-310.pyc ADDED
Binary file (3.68 kB). View file
 
flux/__pycache__/util.cpython-310.pyc ADDED
Binary file (5.75 kB). View file
 
flux/_version.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file generated by setuptools_scm
2
+ # don't change, don't track in version control
3
+ TYPE_CHECKING = False
4
+ if TYPE_CHECKING:
5
+ from typing import Tuple, Union
6
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
7
+ else:
8
+ VERSION_TUPLE = object
9
+
10
+ version: str
11
+ __version__: str
12
+ __version_tuple__: VERSION_TUPLE
13
+ version_tuple: VERSION_TUPLE
14
+
15
+ __version__ = version = '0.0.post0+d20241105'
16
+ __version_tuple__ = version_tuple = (0, 0, 'd20241105')
flux/api.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+ API_ENDPOINT = "https://api.bfl.ml"
10
+
11
+
12
+ class ApiException(Exception):
13
+ def __init__(self, status_code: int, detail: str | list[dict] | None = None):
14
+ super().__init__()
15
+ self.detail = detail
16
+ self.status_code = status_code
17
+
18
+ def __str__(self) -> str:
19
+ return self.__repr__()
20
+
21
+ def __repr__(self) -> str:
22
+ if self.detail is None:
23
+ message = None
24
+ elif isinstance(self.detail, str):
25
+ message = self.detail
26
+ else:
27
+ message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
28
+ return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
29
+
30
+
31
+ class ImageRequest:
32
+ def __init__(
33
+ self,
34
+ prompt: str,
35
+ width: int = 1024,
36
+ height: int = 1024,
37
+ name: str = "flux.1-pro",
38
+ num_steps: int = 50,
39
+ prompt_upsampling: bool = False,
40
+ seed: int | None = None,
41
+ validate: bool = True,
42
+ launch: bool = True,
43
+ api_key: str | None = None,
44
+ ):
45
+ """
46
+ Manages an image generation request to the API.
47
+
48
+ Args:
49
+ prompt: Prompt to sample
50
+ width: Width of the image in pixel
51
+ height: Height of the image in pixel
52
+ name: Name of the model
53
+ num_steps: Number of network evaluations
54
+ prompt_upsampling: Use prompt upsampling
55
+ seed: Fix the generation seed
56
+ validate: Run input validation
57
+ launch: Directly launches request
58
+ api_key: Your API key if not provided by the environment
59
+
60
+ Raises:
61
+ ValueError: For invalid input
62
+ ApiException: For errors raised from the API
63
+ """
64
+ if validate:
65
+ if name not in ["flux.1-pro"]:
66
+ raise ValueError(f"Invalid model {name}")
67
+ elif width % 32 != 0:
68
+ raise ValueError(f"width must be divisible by 32, got {width}")
69
+ elif not (256 <= width <= 1440):
70
+ raise ValueError(f"width must be between 256 and 1440, got {width}")
71
+ elif height % 32 != 0:
72
+ raise ValueError(f"height must be divisible by 32, got {height}")
73
+ elif not (256 <= height <= 1440):
74
+ raise ValueError(f"height must be between 256 and 1440, got {height}")
75
+ elif not (1 <= num_steps <= 50):
76
+ raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
77
+
78
+ self.request_json = {
79
+ "prompt": prompt,
80
+ "width": width,
81
+ "height": height,
82
+ "variant": name,
83
+ "steps": num_steps,
84
+ "prompt_upsampling": prompt_upsampling,
85
+ }
86
+ if seed is not None:
87
+ self.request_json["seed"] = seed
88
+
89
+ self.request_id: str | None = None
90
+ self.result: dict | None = None
91
+ self._image_bytes: bytes | None = None
92
+ self._url: str | None = None
93
+ if api_key is None:
94
+ self.api_key = os.environ.get("BFL_API_KEY")
95
+ else:
96
+ self.api_key = api_key
97
+
98
+ if launch:
99
+ self.request()
100
+
101
+ def request(self):
102
+ """
103
+ Request to generate the image.
104
+ """
105
+ if self.request_id is not None:
106
+ return
107
+ response = requests.post(
108
+ f"{API_ENDPOINT}/v1/image",
109
+ headers={
110
+ "accept": "application/json",
111
+ "x-key": self.api_key,
112
+ "Content-Type": "application/json",
113
+ },
114
+ json=self.request_json,
115
+ )
116
+ result = response.json()
117
+ if response.status_code != 200:
118
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
119
+ self.request_id = response.json()["id"]
120
+
121
+ def retrieve(self) -> dict:
122
+ """
123
+ Wait for the generation to finish and retrieve response.
124
+ """
125
+ if self.request_id is None:
126
+ self.request()
127
+ while self.result is None:
128
+ response = requests.get(
129
+ f"{API_ENDPOINT}/v1/get_result",
130
+ headers={
131
+ "accept": "application/json",
132
+ "x-key": self.api_key,
133
+ },
134
+ params={
135
+ "id": self.request_id,
136
+ },
137
+ )
138
+ result = response.json()
139
+ if "status" not in result:
140
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
141
+ elif result["status"] == "Ready":
142
+ self.result = result["result"]
143
+ elif result["status"] == "Pending":
144
+ time.sleep(0.5)
145
+ else:
146
+ raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
147
+ return self.result
148
+
149
+ @property
150
+ def bytes(self) -> bytes:
151
+ """
152
+ Generated image as bytes.
153
+ """
154
+ if self._image_bytes is None:
155
+ response = requests.get(self.url)
156
+ if response.status_code == 200:
157
+ self._image_bytes = response.content
158
+ else:
159
+ raise ApiException(status_code=response.status_code)
160
+ return self._image_bytes
161
+
162
+ @property
163
+ def url(self) -> str:
164
+ """
165
+ Public url to retrieve the image from
166
+ """
167
+ if self._url is None:
168
+ result = self.retrieve()
169
+ self._url = result["sample"]
170
+ return self._url
171
+
172
+ @property
173
+ def image(self) -> Image.Image:
174
+ """
175
+ Load the image as a PIL Image
176
+ """
177
+ return Image.open(io.BytesIO(self.bytes))
178
+
179
+ def save(self, path: str):
180
+ """
181
+ Save the generated image to a local path
182
+ """
183
+ suffix = Path(self.url).suffix
184
+ if not path.endswith(suffix):
185
+ path = path + suffix
186
+ Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
187
+ with open(path, "wb") as file:
188
+ file.write(self.bytes)
189
+
190
+
191
+ if __name__ == "__main__":
192
+ from fire import Fire
193
+
194
+ Fire(ImageRequest)
flux/math.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+
5
+
6
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
+ q, k = apply_rope(q, k, pe)
8
+
9
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
10
+ x = rearrange(x, "B H L D -> B L (H D)")
11
+
12
+ return x
13
+
14
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
15
+ assert dim % 2 == 0
16
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
17
+ omega = 1.0 / (theta**scale)
18
+ out = torch.einsum("...n,d->...nd", pos, omega)
19
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
20
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
21
+ return out.float()
22
+
23
+
24
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
25
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
26
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
27
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
28
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
29
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
flux/model.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
7
+ MLPEmbedder, SingleStreamBlock,
8
+ timestep_embedding)
9
+
10
+
11
+ @dataclass
12
+ class FluxParams:
13
+ in_channels: int
14
+ vec_in_dim: int
15
+ context_in_dim: int
16
+ hidden_size: int
17
+ mlp_ratio: float
18
+ num_heads: int
19
+ depth: int
20
+ depth_single_blocks: int
21
+ axes_dim: list[int]
22
+ theta: int
23
+ qkv_bias: bool
24
+ guidance_embed: bool
25
+
26
+
27
+ class Flux(nn.Module):
28
+ """
29
+ Transformer model for flow matching on sequences.
30
+ """
31
+
32
+ def __init__(self, params: FluxParams):
33
+ super().__init__()
34
+
35
+ self.params = params
36
+ self.in_channels = params.in_channels
37
+ self.out_channels = self.in_channels
38
+ if params.hidden_size % params.num_heads != 0:
39
+ raise ValueError(
40
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
41
+ )
42
+ pe_dim = params.hidden_size // params.num_heads
43
+ if sum(params.axes_dim) != pe_dim:
44
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
45
+ self.hidden_size = params.hidden_size
46
+ self.num_heads = params.num_heads
47
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
48
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
49
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
50
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
51
+ self.guidance_in = (
52
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
53
+ )
54
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
55
+
56
+ self.double_blocks = nn.ModuleList(
57
+ [
58
+ DoubleStreamBlock(
59
+ self.hidden_size,
60
+ self.num_heads,
61
+ mlp_ratio=params.mlp_ratio,
62
+ qkv_bias=params.qkv_bias,
63
+ )
64
+ for _ in range(params.depth)
65
+ ]
66
+ )
67
+
68
+ self.single_blocks = nn.ModuleList(
69
+ [
70
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
71
+ for _ in range(params.depth_single_blocks)
72
+ ]
73
+ )
74
+
75
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
76
+
77
+ def forward(
78
+ self,
79
+ img: Tensor,
80
+ img_ids: Tensor,
81
+ txt: Tensor,
82
+ txt_ids: Tensor,
83
+ timesteps: Tensor,
84
+ y: Tensor,
85
+ guidance: Tensor | None = None,
86
+ info = None,
87
+ ) -> Tensor:
88
+ if img.ndim != 3 or txt.ndim != 3:
89
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
90
+
91
+ # running on sequences img
92
+ img = self.img_in(img)
93
+ vec = self.time_in(timestep_embedding(timesteps, 256))
94
+ if self.params.guidance_embed:
95
+ if guidance is None:
96
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
97
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
98
+ vec = vec + self.vector_in(y)
99
+ txt = self.txt_in(txt)
100
+
101
+ ids = torch.cat((txt_ids, img_ids), dim=1)
102
+ pe = self.pe_embedder(ids)
103
+
104
+ for block in self.double_blocks:
105
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, info=info)
106
+
107
+ cnt = 0
108
+ img = torch.cat((txt, img), 1)
109
+ info['type'] = 'single'
110
+ for block in self.single_blocks:
111
+ info['id'] = cnt
112
+ img, info = block(img, vec=vec, pe=pe, info=info)
113
+ cnt += 1
114
+
115
+ img = img[:, txt.shape[1] :, ...]
116
+
117
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
118
+ return img, info
flux/modules/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (9.06 kB). View file
 
flux/modules/__pycache__/conditioner.cpython-310.pyc ADDED
Binary file (1.49 kB). View file
 
flux/modules/__pycache__/layers.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
flux/modules/autoencoder.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor, nn
6
+
7
+
8
+ @dataclass
9
+ class AutoEncoderParams:
10
+ resolution: int
11
+ in_channels: int
12
+ ch: int
13
+ out_ch: int
14
+ ch_mult: list[int]
15
+ num_res_blocks: int
16
+ z_channels: int
17
+ scale_factor: float
18
+ shift_factor: float
19
+
20
+
21
+ def swish(x: Tensor) -> Tensor:
22
+ return x * torch.sigmoid(x)
23
+
24
+
25
+ class AttnBlock(nn.Module):
26
+ def __init__(self, in_channels: int):
27
+ super().__init__()
28
+ self.in_channels = in_channels
29
+
30
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
+ if self.in_channels != self.out_channels:
67
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
+
69
+ def forward(self, x):
70
+ h = x
71
+ h = self.norm1(h)
72
+ h = swish(h)
73
+ h = self.conv1(h)
74
+
75
+ h = self.norm2(h)
76
+ h = swish(h)
77
+ h = self.conv2(h)
78
+
79
+ if self.in_channels != self.out_channels:
80
+ x = self.nin_shortcut(x)
81
+
82
+ return x + h
83
+
84
+
85
+ class Downsample(nn.Module):
86
+ def __init__(self, in_channels: int):
87
+ super().__init__()
88
+ # no asymmetric padding in torch conv, must do it ourselves
89
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
+
91
+ def forward(self, x: Tensor):
92
+ pad = (0, 1, 0, 1)
93
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
94
+ x = self.conv(x)
95
+ return x
96
+
97
+
98
+ class Upsample(nn.Module):
99
+ def __init__(self, in_channels: int):
100
+ super().__init__()
101
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
+
103
+ def forward(self, x: Tensor):
104
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Encoder(nn.Module):
110
+ def __init__(
111
+ self,
112
+ resolution: int,
113
+ in_channels: int,
114
+ ch: int,
115
+ ch_mult: list[int],
116
+ num_res_blocks: int,
117
+ z_channels: int,
118
+ ):
119
+ super().__init__()
120
+ self.ch = ch
121
+ self.num_resolutions = len(ch_mult)
122
+ self.num_res_blocks = num_res_blocks
123
+ self.resolution = resolution
124
+ self.in_channels = in_channels
125
+ # downsampling
126
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
+
128
+ curr_res = resolution
129
+ in_ch_mult = (1,) + tuple(ch_mult)
130
+ self.in_ch_mult = in_ch_mult
131
+ self.down = nn.ModuleList()
132
+ block_in = self.ch
133
+ for i_level in range(self.num_resolutions):
134
+ block = nn.ModuleList()
135
+ attn = nn.ModuleList()
136
+ block_in = ch * in_ch_mult[i_level]
137
+ block_out = ch * ch_mult[i_level]
138
+ for _ in range(self.num_res_blocks):
139
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
+ block_in = block_out
141
+ down = nn.Module()
142
+ down.block = block
143
+ down.attn = attn
144
+ if i_level != self.num_resolutions - 1:
145
+ down.downsample = Downsample(block_in)
146
+ curr_res = curr_res // 2
147
+ self.down.append(down)
148
+
149
+ # middle
150
+ self.mid = nn.Module()
151
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
+ self.mid.attn_1 = AttnBlock(block_in)
153
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
+
155
+ # end
156
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
+
159
+ def forward(self, x: Tensor) -> Tensor:
160
+ # downsampling
161
+ hs = [self.conv_in(x)]
162
+ for i_level in range(self.num_resolutions):
163
+ for i_block in range(self.num_res_blocks):
164
+ h = self.down[i_level].block[i_block](hs[-1])
165
+ if len(self.down[i_level].attn) > 0:
166
+ h = self.down[i_level].attn[i_block](h)
167
+ hs.append(h)
168
+ if i_level != self.num_resolutions - 1:
169
+ hs.append(self.down[i_level].downsample(hs[-1]))
170
+
171
+ # middle
172
+ h = hs[-1]
173
+ h = self.mid.block_1(h)
174
+ h = self.mid.attn_1(h)
175
+ h = self.mid.block_2(h)
176
+ # end
177
+ h = self.norm_out(h)
178
+ h = swish(h)
179
+ h = self.conv_out(h)
180
+ return h
181
+
182
+
183
+ class Decoder(nn.Module):
184
+ def __init__(
185
+ self,
186
+ ch: int,
187
+ out_ch: int,
188
+ ch_mult: list[int],
189
+ num_res_blocks: int,
190
+ in_channels: int,
191
+ resolution: int,
192
+ z_channels: int,
193
+ ):
194
+ super().__init__()
195
+ self.ch = ch
196
+ self.num_resolutions = len(ch_mult)
197
+ self.num_res_blocks = num_res_blocks
198
+ self.resolution = resolution
199
+ self.in_channels = in_channels
200
+ self.ffactor = 2 ** (self.num_resolutions - 1)
201
+
202
+ # compute in_ch_mult, block_in and curr_res at lowest res
203
+ block_in = ch * ch_mult[self.num_resolutions - 1]
204
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
+ self.z_shape = (1, z_channels, curr_res, curr_res)
206
+
207
+ # z to block_in
208
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
+
210
+ # middle
211
+ self.mid = nn.Module()
212
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
+ self.mid.attn_1 = AttnBlock(block_in)
214
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
+
216
+ # upsampling
217
+ self.up = nn.ModuleList()
218
+ for i_level in reversed(range(self.num_resolutions)):
219
+ block = nn.ModuleList()
220
+ attn = nn.ModuleList()
221
+ block_out = ch * ch_mult[i_level]
222
+ for _ in range(self.num_res_blocks + 1):
223
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
+ block_in = block_out
225
+ up = nn.Module()
226
+ up.block = block
227
+ up.attn = attn
228
+ if i_level != 0:
229
+ up.upsample = Upsample(block_in)
230
+ curr_res = curr_res * 2
231
+ self.up.insert(0, up) # prepend to get consistent order
232
+
233
+ # end
234
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
+
237
+ def forward(self, z: Tensor) -> Tensor:
238
+ # z to block_in
239
+ h = self.conv_in(z)
240
+
241
+ # middle
242
+ h = self.mid.block_1(h)
243
+ h = self.mid.attn_1(h)
244
+ h = self.mid.block_2(h)
245
+
246
+ # upsampling
247
+ for i_level in reversed(range(self.num_resolutions)):
248
+ for i_block in range(self.num_res_blocks + 1):
249
+ h = self.up[i_level].block[i_block](h)
250
+ if len(self.up[i_level].attn) > 0:
251
+ h = self.up[i_level].attn[i_block](h)
252
+ if i_level != 0:
253
+ h = self.up[i_level].upsample(h)
254
+
255
+ # end
256
+ h = self.norm_out(h)
257
+ h = swish(h)
258
+ h = self.conv_out(h)
259
+ return h
260
+
261
+
262
+ class DiagonalGaussian(nn.Module):
263
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
264
+ super().__init__()
265
+ self.sample = sample
266
+ self.chunk_dim = chunk_dim
267
+
268
+ def forward(self, z: Tensor) -> Tensor:
269
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270
+ # import pdb;pdb.set_trace()
271
+ if self.sample:
272
+ std = torch.exp(0.5 * logvar)
273
+ return mean #+ std * torch.randn_like(mean)
274
+ else:
275
+ return mean
276
+
277
+
278
+ class AutoEncoder(nn.Module):
279
+ def __init__(self, params: AutoEncoderParams):
280
+ super().__init__()
281
+ self.encoder = Encoder(
282
+ resolution=params.resolution,
283
+ in_channels=params.in_channels,
284
+ ch=params.ch,
285
+ ch_mult=params.ch_mult,
286
+ num_res_blocks=params.num_res_blocks,
287
+ z_channels=params.z_channels,
288
+ )
289
+ self.decoder = Decoder(
290
+ resolution=params.resolution,
291
+ in_channels=params.in_channels,
292
+ ch=params.ch,
293
+ out_ch=params.out_ch,
294
+ ch_mult=params.ch_mult,
295
+ num_res_blocks=params.num_res_blocks,
296
+ z_channels=params.z_channels,
297
+ )
298
+ self.reg = DiagonalGaussian()
299
+
300
+ self.scale_factor = params.scale_factor
301
+ self.shift_factor = params.shift_factor
302
+
303
+ def encode(self, x: Tensor) -> Tensor:
304
+ z = self.reg(self.encoder(x))
305
+ z = self.scale_factor * (z - self.shift_factor)
306
+ return z
307
+
308
+ def decode(self, z: Tensor) -> Tensor:
309
+ z = z / self.scale_factor + self.shift_factor
310
+ return self.decoder(z)
311
+
312
+ def forward(self, x: Tensor) -> Tensor:
313
+ return self.decode(self.encode(x))
flux/modules/conditioner.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor, nn
2
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
3
+ T5Tokenizer)
4
+
5
+
6
+ class HFEmbedder(nn.Module):
7
+ def __init__(self, version: str, max_length: int, is_clip, **hf_kwargs):
8
+ super().__init__()
9
+ self.is_clip = is_clip
10
+ self.max_length = max_length
11
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
12
+
13
+ if self.is_clip:
14
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
15
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
16
+ else:
17
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
18
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
19
+
20
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
21
+
22
+ def forward(self, text: list[str]) -> Tensor:
23
+ batch_encoding = self.tokenizer(
24
+ text,
25
+ truncation=True,
26
+ max_length=self.max_length,
27
+ return_length=False,
28
+ return_overflowing_tokens=False,
29
+ padding="max_length",
30
+ return_tensors="pt",
31
+ )
32
+
33
+ outputs = self.hf_module(
34
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
35
+ attention_mask=None,
36
+ output_hidden_states=False,
37
+ )
38
+ return outputs[self.output_key]
flux/modules/layers.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from flux.math import attention, rope
9
+
10
+ import os
11
+
12
+ class EmbedND(nn.Module):
13
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.theta = theta
17
+ self.axes_dim = axes_dim
18
+
19
+ def forward(self, ids: Tensor) -> Tensor:
20
+ n_axes = ids.shape[-1]
21
+ emb = torch.cat(
22
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
23
+ dim=-3,
24
+ )
25
+
26
+ return emb.unsqueeze(1)
27
+
28
+
29
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
30
+ """
31
+ Create sinusoidal timestep embeddings.
32
+ :param t: a 1-D Tensor of N indices, one per batch element.
33
+ These may be fractional.
34
+ :param dim: the dimension of the output.
35
+ :param max_period: controls the minimum frequency of the embeddings.
36
+ :return: an (N, D) Tensor of positional embeddings.
37
+ """
38
+ t = time_factor * t
39
+ half = dim // 2
40
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
41
+ t.device
42
+ )
43
+
44
+ args = t[:, None].float() * freqs[None]
45
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
46
+ if dim % 2:
47
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
48
+ if torch.is_floating_point(t):
49
+ embedding = embedding.to(t)
50
+ return embedding
51
+
52
+
53
+ class MLPEmbedder(nn.Module):
54
+ def __init__(self, in_dim: int, hidden_dim: int):
55
+ super().__init__()
56
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
57
+ self.silu = nn.SiLU()
58
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
59
+
60
+ def forward(self, x: Tensor) -> Tensor:
61
+ return self.out_layer(self.silu(self.in_layer(x)))
62
+
63
+
64
+ class RMSNorm(torch.nn.Module):
65
+ def __init__(self, dim: int):
66
+ super().__init__()
67
+ self.scale = nn.Parameter(torch.ones(dim))
68
+
69
+ def forward(self, x: Tensor):
70
+ x_dtype = x.dtype
71
+ x = x.float()
72
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
73
+ return (x * rrms).to(dtype=x_dtype) * self.scale
74
+
75
+
76
+ class QKNorm(torch.nn.Module):
77
+ def __init__(self, dim: int):
78
+ super().__init__()
79
+ self.query_norm = RMSNorm(dim)
80
+ self.key_norm = RMSNorm(dim)
81
+
82
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
83
+ q = self.query_norm(q)
84
+ k = self.key_norm(k)
85
+ return q.to(v), k.to(v)
86
+
87
+
88
+ class SelfAttention(nn.Module):
89
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
90
+ super().__init__()
91
+ self.num_heads = num_heads
92
+ head_dim = dim // num_heads
93
+
94
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
95
+ self.norm = QKNorm(head_dim)
96
+ self.proj = nn.Linear(dim, dim)
97
+
98
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
99
+ qkv = self.qkv(x)
100
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
101
+ q, k = self.norm(q, k, v)
102
+ x = attention(q, k, v, pe=pe)
103
+ x = self.proj(x)
104
+ return x
105
+
106
+
107
+ @dataclass
108
+ class ModulationOut:
109
+ shift: Tensor
110
+ scale: Tensor
111
+ gate: Tensor
112
+
113
+
114
+ class Modulation(nn.Module):
115
+ def __init__(self, dim: int, double: bool):
116
+ super().__init__()
117
+ self.is_double = double
118
+ self.multiplier = 6 if double else 3
119
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
120
+
121
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
122
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
123
+
124
+ return (
125
+ ModulationOut(*out[:3]),
126
+ ModulationOut(*out[3:]) if self.is_double else None,
127
+ )
128
+
129
+
130
+ class DoubleStreamBlock(nn.Module):
131
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
132
+ super().__init__()
133
+
134
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
135
+ self.num_heads = num_heads
136
+ self.hidden_size = hidden_size
137
+ self.img_mod = Modulation(hidden_size, double=True)
138
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
139
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
140
+
141
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
142
+ self.img_mlp = nn.Sequential(
143
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
144
+ nn.GELU(approximate="tanh"),
145
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
146
+ )
147
+
148
+ self.txt_mod = Modulation(hidden_size, double=True)
149
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
150
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
151
+
152
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
153
+ self.txt_mlp = nn.Sequential(
154
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
155
+ nn.GELU(approximate="tanh"),
156
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
157
+ )
158
+
159
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, info) -> tuple[Tensor, Tensor]:
160
+ img_mod1, img_mod2 = self.img_mod(vec)
161
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
162
+
163
+ # prepare image for attention
164
+ img_modulated = self.img_norm1(img)
165
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
166
+ img_qkv = self.img_attn.qkv(img_modulated)
167
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
168
+
169
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
170
+
171
+ # if info['inject']:
172
+ # if info['inverse']:
173
+ # print("!save! ",info['feature_path'] + str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'])
174
+ # torch.save(img_q, info['feature_path'] + str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'Q' + '.pth')
175
+ # if not info['inverse']:
176
+ # print("!load! ", info['feature_path'] + str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'])
177
+ # img_q = torch.load(info['feature_path'] + str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'Q' + '.pth', weights_only=True)
178
+
179
+ # prepare txt for attention
180
+ txt_modulated = self.txt_norm1(txt)
181
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
182
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
183
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
184
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
185
+
186
+ # run actual attention
187
+ q = torch.cat((txt_q, img_q), dim=2) #[8, 24, 512, 128] + [8, 24, 900, 128] -> [8, 24, 1412, 128]
188
+ k = torch.cat((txt_k, img_k), dim=2)
189
+ v = torch.cat((txt_v, img_v), dim=2)
190
+ # import pdb;pdb.set_trace()
191
+ attn = attention(q, k, v, pe=pe)
192
+
193
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
194
+
195
+ # calculate the img bloks
196
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
197
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
198
+
199
+ # calculate the txt bloks
200
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
201
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
202
+ return img, txt
203
+
204
+
205
+ class SingleStreamBlock(nn.Module):
206
+ """
207
+ A DiT block with parallel linear layers as described in
208
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ hidden_size: int,
214
+ num_heads: int,
215
+ mlp_ratio: float = 4.0,
216
+ qk_scale: float | None = None,
217
+ ):
218
+ super().__init__()
219
+ self.hidden_dim = hidden_size
220
+ self.num_heads = num_heads
221
+ head_dim = hidden_size // num_heads
222
+ self.scale = qk_scale or head_dim**-0.5
223
+
224
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
225
+ # qkv and mlp_in
226
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
227
+ # proj and mlp_out
228
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
229
+
230
+ self.norm = QKNorm(head_dim)
231
+
232
+ self.hidden_size = hidden_size
233
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
234
+
235
+ self.mlp_act = nn.GELU(approximate="tanh")
236
+ self.modulation = Modulation(hidden_size, double=False)
237
+
238
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, info) -> Tensor:
239
+ mod, _ = self.modulation(vec)
240
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
241
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
242
+
243
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
244
+ q, k = self.norm(q, k, v)
245
+
246
+ # Note: If the memory of your device is not enough, you may consider uncomment the following code.
247
+ # if info['inject'] and info['id'] > 19:
248
+ # store_path = os.path.join(info['feature_path'], str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'V' + '.pth')
249
+ # if info['inverse']:
250
+ # torch.save(v, store_path)
251
+ # if not info['inverse']:
252
+ # v = torch.load(store_path, weights_only=True)
253
+
254
+ # Save the features in the memory
255
+ if info['inject'] and info['id'] > 19:
256
+ feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'V'
257
+ if info['inverse']:
258
+ info['feature'][feature_name] = v.cpu()
259
+ else:
260
+ v = info['feature'][feature_name].cuda()
261
+
262
+ # compute attention
263
+ attn = attention(q, k, v, pe=pe)
264
+ # compute activation in mlp stream, cat again and run second linear layer
265
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
266
+ return x + mod.gate * output, info
267
+
268
+
269
+ class LastLayer(nn.Module):
270
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
271
+ super().__init__()
272
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
273
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
274
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
275
+
276
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
277
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
278
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
279
+ x = self.linear(x)
280
+ return x
flux/sampling.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Callable
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+ from torch import Tensor
7
+
8
+ from .model import Flux
9
+ from .modules.conditioner import HFEmbedder
10
+
11
+
12
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
13
+ bs, c, h, w = img.shape
14
+ if bs == 1 and not isinstance(prompt, str):
15
+ bs = len(prompt)
16
+
17
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
18
+ if img.shape[0] == 1 and bs > 1:
19
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
20
+
21
+ img_ids = torch.zeros(h // 2, w // 2, 3)
22
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
23
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
24
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
25
+
26
+ if isinstance(prompt, str):
27
+ prompt = [prompt]
28
+ txt = t5(prompt)
29
+ if txt.shape[0] == 1 and bs > 1:
30
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
31
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
32
+
33
+ vec = clip(prompt)
34
+ if vec.shape[0] == 1 and bs > 1:
35
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
36
+
37
+ return {
38
+ "img": img,
39
+ "img_ids": img_ids.to(img.device),
40
+ "txt": txt.to(img.device),
41
+ "txt_ids": txt_ids.to(img.device),
42
+ "vec": vec.to(img.device),
43
+ }
44
+
45
+
46
+ def time_shift(mu: float, sigma: float, t: Tensor):
47
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
48
+
49
+
50
+ def get_lin_function(
51
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
52
+ ) -> Callable[[float], float]:
53
+ m = (y2 - y1) / (x2 - x1)
54
+ b = y1 - m * x1
55
+ return lambda x: m * x + b
56
+
57
+
58
+ def get_schedule(
59
+ num_steps: int,
60
+ image_seq_len: int,
61
+ base_shift: float = 0.5,
62
+ max_shift: float = 1.15,
63
+ shift: bool = True,
64
+ ) -> list[float]:
65
+ # extra step for zero
66
+ timesteps = torch.linspace(1, 0, num_steps + 1)
67
+
68
+ # shifting the schedule to favor high timesteps for higher signal images
69
+ if shift:
70
+ # estimate mu based on linear estimation between two points
71
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
72
+ timesteps = time_shift(mu, 1.0, timesteps)
73
+
74
+ return timesteps.tolist()
75
+
76
+
77
+ def denoise(
78
+ model: Flux,
79
+ # model input
80
+ img: Tensor,
81
+ img_ids: Tensor,
82
+ txt: Tensor,
83
+ txt_ids: Tensor,
84
+ vec: Tensor,
85
+ # sampling parameters
86
+ timesteps: list[float],
87
+ inverse,
88
+ info,
89
+ guidance: float = 4.0
90
+ ):
91
+ # this is ignored for schnell
92
+ inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
93
+
94
+ if inverse:
95
+ timesteps = timesteps[::-1]
96
+ inject_list = inject_list[::-1]
97
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
98
+
99
+ step_list = []
100
+ for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
101
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
102
+ info['t'] = t_prev if inverse else t_curr
103
+ info['inverse'] = inverse
104
+ info['second_order'] = False
105
+ info['inject'] = inject_list[i]
106
+
107
+ pred, info = model(
108
+ img=img,
109
+ img_ids=img_ids,
110
+ txt=txt,
111
+ txt_ids=txt_ids,
112
+ y=vec,
113
+ timesteps=t_vec,
114
+ guidance=guidance_vec,
115
+ info=info
116
+ )
117
+
118
+ img_mid = img + (t_prev - t_curr) / 2 * pred
119
+
120
+ t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
121
+ info['second_order'] = True
122
+ pred_mid, info = model(
123
+ img=img_mid,
124
+ img_ids=img_ids,
125
+ txt=txt,
126
+ txt_ids=txt_ids,
127
+ y=vec,
128
+ timesteps=t_vec_mid,
129
+ guidance=guidance_vec,
130
+ info=info
131
+ )
132
+
133
+ first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2)
134
+ img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order
135
+
136
+ return img, info
137
+
138
+
139
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
140
+ return rearrange(
141
+ x,
142
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
143
+ h=math.ceil(height / 16),
144
+ w=math.ceil(width / 16),
145
+ ph=2,
146
+ pw=2,
147
+ )
flux/util.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from imwatermark import WatermarkEncoder
8
+ from safetensors.torch import load_file as load_sft
9
+
10
+ from flux.model import Flux, FluxParams
11
+ from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
12
+ from flux.modules.conditioner import HFEmbedder
13
+
14
+
15
+ @dataclass
16
+ class ModelSpec:
17
+ params: FluxParams
18
+ ae_params: AutoEncoderParams
19
+ ckpt_path: str | None
20
+ ae_path: str | None
21
+ repo_id: str | None
22
+ repo_flow: str | None
23
+ repo_ae: str | None
24
+
25
+ configs = {
26
+ "flux-dev": ModelSpec(
27
+ repo_id="black-forest-labs/FLUX.1-dev",
28
+ repo_flow="flux1-dev.safetensors",
29
+ repo_ae="ae.safetensors",
30
+ ckpt_path=os.getenv("FLUX_DEV"),
31
+ params=FluxParams(
32
+ in_channels=64,
33
+ vec_in_dim=768,
34
+ context_in_dim=4096,
35
+ hidden_size=3072,
36
+ mlp_ratio=4.0,
37
+ num_heads=24,
38
+ depth=19,
39
+ depth_single_blocks=38,
40
+ axes_dim=[16, 56, 56],
41
+ theta=10_000,
42
+ qkv_bias=True,
43
+ guidance_embed=True,
44
+ ),
45
+ ae_path=os.getenv("AE"),
46
+ ae_params=AutoEncoderParams(
47
+ resolution=256,
48
+ in_channels=3,
49
+ ch=128,
50
+ out_ch=3,
51
+ ch_mult=[1, 2, 4, 4],
52
+ num_res_blocks=2,
53
+ z_channels=16,
54
+ scale_factor=0.3611,
55
+ shift_factor=0.1159,
56
+ ),
57
+ ),
58
+ "flux-schnell": ModelSpec(
59
+ repo_id="black-forest-labs/FLUX.1-schnell",
60
+ repo_flow="flux1-schnell.safetensors",
61
+ repo_ae="ae.safetensors",
62
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
63
+ params=FluxParams(
64
+ in_channels=64,
65
+ vec_in_dim=768,
66
+ context_in_dim=4096,
67
+ hidden_size=3072,
68
+ mlp_ratio=4.0,
69
+ num_heads=24,
70
+ depth=19,
71
+ depth_single_blocks=38,
72
+ axes_dim=[16, 56, 56],
73
+ theta=10_000,
74
+ qkv_bias=True,
75
+ guidance_embed=False,
76
+ ),
77
+ ae_path=os.getenv("AE"),
78
+ ae_params=AutoEncoderParams(
79
+ resolution=256,
80
+ in_channels=3,
81
+ ch=128,
82
+ out_ch=3,
83
+ ch_mult=[1, 2, 4, 4],
84
+ num_res_blocks=2,
85
+ z_channels=16,
86
+ scale_factor=0.3611,
87
+ shift_factor=0.1159,
88
+ ),
89
+ ),
90
+ }
91
+
92
+
93
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
94
+ if len(missing) > 0 and len(unexpected) > 0:
95
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
96
+ print("\n" + "-" * 79 + "\n")
97
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
98
+ elif len(missing) > 0:
99
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
100
+ elif len(unexpected) > 0:
101
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
102
+
103
+
104
+ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
105
+ # Loading Flux
106
+ print("Init model")
107
+
108
+ ckpt_path = configs[name].ckpt_path
109
+ if (
110
+ ckpt_path is None
111
+ and configs[name].repo_id is not None
112
+ and configs[name].repo_flow is not None
113
+ and hf_download
114
+ ):
115
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
116
+
117
+ with torch.device("meta" if ckpt_path is not None else device):
118
+ model = Flux(configs[name].params).to(torch.bfloat16)
119
+
120
+ if ckpt_path is not None:
121
+ print("Loading checkpoint")
122
+ # load_sft doesn't support torch.device
123
+ sd = load_sft(ckpt_path, device=str(device))
124
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
125
+ print_load_warning(missing, unexpected)
126
+ return model
127
+
128
+
129
+ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
130
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
131
+ return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
132
+
133
+
134
+ def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
135
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
136
+
137
+
138
+ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
139
+ ckpt_path = configs[name].ae_path
140
+ if (
141
+ ckpt_path is None
142
+ and configs[name].repo_id is not None
143
+ and configs[name].repo_ae is not None
144
+ and hf_download
145
+ ):
146
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
147
+
148
+ # Loading the autoencoder
149
+ print("Init AE")
150
+ with torch.device("meta" if ckpt_path is not None else device):
151
+ ae = AutoEncoder(configs[name].ae_params)
152
+
153
+ if ckpt_path is not None:
154
+ sd = load_sft(ckpt_path, device=str(device))
155
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
156
+ print_load_warning(missing, unexpected)
157
+ return ae
158
+
159
+
160
+ class WatermarkEmbedder:
161
+ def __init__(self, watermark):
162
+ self.watermark = watermark
163
+ self.num_bits = len(WATERMARK_BITS)
164
+ self.encoder = WatermarkEncoder()
165
+ self.encoder.set_watermark("bits", self.watermark)
166
+
167
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
168
+ """
169
+ Adds a predefined watermark to the input image
170
+
171
+ Args:
172
+ image: ([N,] B, RGB, H, W) in range [-1, 1]
173
+
174
+ Returns:
175
+ same as input but watermarked
176
+ """
177
+ image = 0.5 * image + 0.5
178
+ squeeze = len(image.shape) == 4
179
+ if squeeze:
180
+ image = image[None, ...]
181
+ n = image.shape[0]
182
+ image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
183
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
184
+ # watermarking libary expects input as cv2 BGR format
185
+ for k in range(image_np.shape[0]):
186
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
187
+ image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
188
+ image.device
189
+ )
190
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
191
+ if squeeze:
192
+ image = image[0]
193
+ image = 2 * image - 1
194
+ return image
195
+
196
+
197
+ # A fixed 48-bit message that was chosen at random
198
+ WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
199
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
200
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
201
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)