wjs0725 commited on
Commit
e1d7eb4
1 Parent(s): 0d25537

Delete src

Browse files
src/edit.py DELETED
@@ -1,248 +0,0 @@
1
- import os
2
- import re
3
- import time
4
- from dataclasses import dataclass
5
- from glob import iglob
6
- import argparse
7
- import torch
8
- from einops import rearrange
9
- from fire import Fire
10
- from PIL import ExifTags, Image
11
-
12
- from flux.sampling import denoise, get_schedule, prepare, unpack
13
- from flux.util import (configs, embed_watermark, load_ae, load_clip,
14
- load_flow_model, load_t5)
15
- from transformers import pipeline
16
- from PIL import Image
17
- import numpy as np
18
-
19
- import os
20
-
21
- NSFW_THRESHOLD = 0.85
22
-
23
- @dataclass
24
- class SamplingOptions:
25
- source_prompt: str
26
- target_prompt: str
27
- # prompt: str
28
- width: int
29
- height: int
30
- num_steps: int
31
- guidance: float
32
- seed: int | None
33
-
34
- @torch.inference_mode()
35
- def encode(init_image, torch_device, ae):
36
- init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
37
- init_image = init_image.unsqueeze(0)
38
- init_image = init_image.to(torch_device)
39
- init_image = ae.encode(init_image.to()).to(torch.bfloat16)
40
- return init_image
41
-
42
- @torch.inference_mode()
43
- def main(
44
- args,
45
- seed: int | None = None,
46
- device: str = "cuda" if torch.cuda.is_available() else "cpu",
47
- num_steps: int | None = None,
48
- loop: bool = False,
49
- offload: bool = False,
50
- add_sampling_metadata: bool = True,
51
- ):
52
- """
53
- Sample the flux model. Either interactively (set `--loop`) or run for a
54
- single image.
55
-
56
- Args:
57
- name: Name of the model to load
58
- height: height of the sample in pixels (should be a multiple of 16)
59
- width: width of the sample in pixels (should be a multiple of 16)
60
- seed: Set a seed for sampling
61
- output_name: where to save the output image, `{idx}` will be replaced
62
- by the index of the sample
63
- prompt: Prompt used for sampling
64
- device: Pytorch device
65
- num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
66
- loop: start an interactive session and sample multiple times
67
- guidance: guidance value used for guidance distillation
68
- add_sampling_metadata: Add the prompt to the image Exif metadata
69
- """
70
- torch.set_grad_enabled(False)
71
- name = args.name
72
- source_prompt = args.source_prompt
73
- target_prompt = args.target_prompt
74
- guidance = args.guidance
75
- output_dir = args.output_dir
76
- num_steps = args.num_steps
77
- offload = args.offload
78
-
79
- nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
80
-
81
- if name not in configs:
82
- available = ", ".join(configs.keys())
83
- raise ValueError(f"Got unknown model name: {name}, chose from {available}")
84
-
85
- torch_device = torch.device(device)
86
- if num_steps is None:
87
- num_steps = 4 if name == "flux-schnell" else 25
88
-
89
- # init all components
90
- t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
91
- clip = load_clip(torch_device)
92
- model = load_flow_model(name, device="cpu" if offload else torch_device)
93
- ae = load_ae(name, device="cpu" if offload else torch_device)
94
-
95
- if offload:
96
- model.cpu()
97
- torch.cuda.empty_cache()
98
- ae.encoder.to(torch_device)
99
-
100
- init_image = None
101
- init_image = np.array(Image.open(args.source_img_dir).convert('RGB'))
102
-
103
- shape = init_image.shape
104
-
105
- new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
106
- new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
107
-
108
- init_image = init_image[:new_h, :new_w, :]
109
-
110
- width, height = init_image.shape[0], init_image.shape[1]
111
- init_image = encode(init_image, torch_device, ae)
112
-
113
- rng = torch.Generator(device="cpu")
114
- opts = SamplingOptions(
115
- source_prompt=source_prompt,
116
- target_prompt=target_prompt,
117
- width=width,
118
- height=height,
119
- num_steps=num_steps,
120
- guidance=guidance,
121
- seed=seed,
122
- )
123
-
124
- if loop:
125
- opts = parse_prompt(opts)
126
-
127
- while opts is not None:
128
- if opts.seed is None:
129
- opts.seed = rng.seed()
130
- print(f"Generating with seed {opts.seed}:\n{opts.source_prompt}")
131
- t0 = time.perf_counter()
132
-
133
- opts.seed = None
134
- if offload:
135
- ae = ae.cpu()
136
- torch.cuda.empty_cache()
137
- t5, clip = t5.to(torch_device), clip.to(torch_device)
138
-
139
- info = {}
140
- info['feature_path'] = args.feature_path
141
- info['feature'] = {}
142
- info['inject_step'] = args.inject
143
- if not os.path.exists(args.feature_path):
144
- os.mkdir(args.feature_path)
145
-
146
- inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
147
- inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)
148
- timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
149
-
150
- # offload TEs to CPU, load model to gpu
151
- if offload:
152
- t5, clip = t5.cpu(), clip.cpu()
153
- torch.cuda.empty_cache()
154
- model = model.to(torch_device)
155
-
156
- # inversion initial noise
157
- z, info = denoise(model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
158
-
159
- inp_target["img"] = z
160
-
161
- timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(name != "flux-schnell"))
162
-
163
- # denoise initial noise
164
- x, _ = denoise(model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
165
-
166
- if offload:
167
- model.cpu()
168
- torch.cuda.empty_cache()
169
- ae.decoder.to(x.device)
170
-
171
- # decode latents to pixel space
172
- batch_x = unpack(x.float(), opts.width, opts.height)
173
-
174
- for x in batch_x:
175
- x = x.unsqueeze(0)
176
- output_name = os.path.join(output_dir, "img_{idx}.jpg")
177
- if not os.path.exists(output_dir):
178
- os.makedirs(output_dir)
179
- idx = 0
180
- else:
181
- fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
182
- if len(fns) > 0:
183
- idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
184
- else:
185
- idx = 0
186
-
187
- with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
188
- x = ae.decode(x)
189
-
190
- if torch.cuda.is_available():
191
- torch.cuda.synchronize()
192
- t1 = time.perf_counter()
193
-
194
- fn = output_name.format(idx=idx)
195
- print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
196
- # bring into PIL format and save
197
- x = x.clamp(-1, 1)
198
- x = embed_watermark(x.float())
199
- x = rearrange(x[0], "c h w -> h w c")
200
-
201
- img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
202
- nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
203
-
204
- if nsfw_score < NSFW_THRESHOLD:
205
- exif_data = Image.Exif()
206
- exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
207
- exif_data[ExifTags.Base.Make] = "Black Forest Labs"
208
- exif_data[ExifTags.Base.Model] = name
209
- if add_sampling_metadata:
210
- exif_data[ExifTags.Base.ImageDescription] = source_prompt
211
- img.save(fn, exif=exif_data, quality=95, subsampling=0)
212
- idx += 1
213
- else:
214
- print("Your generated image may contain NSFW content.")
215
-
216
- if loop:
217
- print("-" * 80)
218
- opts = parse_prompt(opts)
219
- else:
220
- opts = None
221
-
222
- if __name__ == "__main__":
223
-
224
- parser = argparse.ArgumentParser(description='RF-Edit')
225
-
226
- parser.add_argument('--name', default='flux-dev', type=str,
227
- help='flux model')
228
- parser.add_argument('--source_img_dir', default='', type=str,
229
- help='The path of the source image')
230
- parser.add_argument('--source_prompt', type=str,
231
- help='describe the content of the source image (or leaves it as null)')
232
- parser.add_argument('--target_prompt', type=str,
233
- help='describe the requirement of editing')
234
- parser.add_argument('--feature_path', type=str, default='feature',
235
- help='the path to save the feature ')
236
- parser.add_argument('--guidance', type=float, default=5,
237
- help='guidance scale')
238
- parser.add_argument('--num_steps', type=int, default=25,
239
- help='the number of timesteps for inversion and denoising')
240
- parser.add_argument('--inject', type=int, default=20,
241
- help='the number of timesteps which apply the feature sharing')
242
- parser.add_argument('--output_dir', default='output', type=str,
243
- help='the path of the edited image')
244
- parser.add_argument('--offload', action='store_true', help='set it to True if the memory of GPU is not enough')
245
-
246
- args = parser.parse_args()
247
-
248
- main(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/examples/edit/boy.jpg DELETED
Binary file (363 kB)
 
src/examples/edit/hiking.jpg DELETED
Binary file (335 kB)
 
src/examples/edit/horse.jpg DELETED
Binary file (391 kB)
 
src/examples/source/art.jpg DELETED

Git LFS Details

  • SHA256: 3d4c7daf7d513265fe95efa65e4f4511e893bbd85c7ed30034548827fd6f5acc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
src/examples/source/boy.jpg DELETED
Binary file (337 kB)
 
src/examples/source/cartoon.jpg DELETED
Binary file (399 kB)
 
src/examples/source/hiking.jpg DELETED
Binary file (286 kB)
 
src/examples/source/horse.jpg DELETED
Binary file (381 kB)
 
src/examples/source/nobel.jpg DELETED
Binary file (674 kB)
 
src/flux/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- try:
2
- from ._version import version as __version__ # type: ignore
3
- from ._version import version_tuple
4
- except ImportError:
5
- __version__ = "unknown (no version information available)"
6
- version_tuple = (0, 0, "unknown", "noinfo")
7
-
8
- from pathlib import Path
9
-
10
- PACKAGE = __package__.replace("_", "-")
11
- PACKAGE_ROOT = Path(__file__).parent
 
 
 
 
 
 
 
 
 
 
 
 
src/flux/__main__.py DELETED
@@ -1,4 +0,0 @@
1
- from .cli import app
2
-
3
- if __name__ == "__main__":
4
- app()
 
 
 
 
 
src/flux/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (503 Bytes)
 
src/flux/__pycache__/math.cpython-310.pyc DELETED
Binary file (1.6 kB)
 
src/flux/__pycache__/math.cpython-38.pyc DELETED
Binary file (1.46 kB)
 
src/flux/__pycache__/model.cpython-310.pyc DELETED
Binary file (3.46 kB)
 
src/flux/__pycache__/sampling.cpython-310.pyc DELETED
Binary file (3.68 kB)
 
src/flux/__pycache__/util.cpython-310.pyc DELETED
Binary file (5.75 kB)
 
src/flux/_version.py DELETED
@@ -1,16 +0,0 @@
1
- # file generated by setuptools_scm
2
- # don't change, don't track in version control
3
- TYPE_CHECKING = False
4
- if TYPE_CHECKING:
5
- from typing import Tuple, Union
6
- VERSION_TUPLE = Tuple[Union[int, str], ...]
7
- else:
8
- VERSION_TUPLE = object
9
-
10
- version: str
11
- __version__: str
12
- __version_tuple__: VERSION_TUPLE
13
- version_tuple: VERSION_TUPLE
14
-
15
- __version__ = version = '0.0.post0+d20241105'
16
- __version_tuple__ = version_tuple = (0, 0, 'd20241105')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/flux/api.py DELETED
@@ -1,194 +0,0 @@
1
- import io
2
- import os
3
- import time
4
- from pathlib import Path
5
-
6
- import requests
7
- from PIL import Image
8
-
9
- API_ENDPOINT = "https://api.bfl.ml"
10
-
11
-
12
- class ApiException(Exception):
13
- def __init__(self, status_code: int, detail: str | list[dict] | None = None):
14
- super().__init__()
15
- self.detail = detail
16
- self.status_code = status_code
17
-
18
- def __str__(self) -> str:
19
- return self.__repr__()
20
-
21
- def __repr__(self) -> str:
22
- if self.detail is None:
23
- message = None
24
- elif isinstance(self.detail, str):
25
- message = self.detail
26
- else:
27
- message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
28
- return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
29
-
30
-
31
- class ImageRequest:
32
- def __init__(
33
- self,
34
- prompt: str,
35
- width: int = 1024,
36
- height: int = 1024,
37
- name: str = "flux.1-pro",
38
- num_steps: int = 50,
39
- prompt_upsampling: bool = False,
40
- seed: int | None = None,
41
- validate: bool = True,
42
- launch: bool = True,
43
- api_key: str | None = None,
44
- ):
45
- """
46
- Manages an image generation request to the API.
47
-
48
- Args:
49
- prompt: Prompt to sample
50
- width: Width of the image in pixel
51
- height: Height of the image in pixel
52
- name: Name of the model
53
- num_steps: Number of network evaluations
54
- prompt_upsampling: Use prompt upsampling
55
- seed: Fix the generation seed
56
- validate: Run input validation
57
- launch: Directly launches request
58
- api_key: Your API key if not provided by the environment
59
-
60
- Raises:
61
- ValueError: For invalid input
62
- ApiException: For errors raised from the API
63
- """
64
- if validate:
65
- if name not in ["flux.1-pro"]:
66
- raise ValueError(f"Invalid model {name}")
67
- elif width % 32 != 0:
68
- raise ValueError(f"width must be divisible by 32, got {width}")
69
- elif not (256 <= width <= 1440):
70
- raise ValueError(f"width must be between 256 and 1440, got {width}")
71
- elif height % 32 != 0:
72
- raise ValueError(f"height must be divisible by 32, got {height}")
73
- elif not (256 <= height <= 1440):
74
- raise ValueError(f"height must be between 256 and 1440, got {height}")
75
- elif not (1 <= num_steps <= 50):
76
- raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
77
-
78
- self.request_json = {
79
- "prompt": prompt,
80
- "width": width,
81
- "height": height,
82
- "variant": name,
83
- "steps": num_steps,
84
- "prompt_upsampling": prompt_upsampling,
85
- }
86
- if seed is not None:
87
- self.request_json["seed"] = seed
88
-
89
- self.request_id: str | None = None
90
- self.result: dict | None = None
91
- self._image_bytes: bytes | None = None
92
- self._url: str | None = None
93
- if api_key is None:
94
- self.api_key = os.environ.get("BFL_API_KEY")
95
- else:
96
- self.api_key = api_key
97
-
98
- if launch:
99
- self.request()
100
-
101
- def request(self):
102
- """
103
- Request to generate the image.
104
- """
105
- if self.request_id is not None:
106
- return
107
- response = requests.post(
108
- f"{API_ENDPOINT}/v1/image",
109
- headers={
110
- "accept": "application/json",
111
- "x-key": self.api_key,
112
- "Content-Type": "application/json",
113
- },
114
- json=self.request_json,
115
- )
116
- result = response.json()
117
- if response.status_code != 200:
118
- raise ApiException(status_code=response.status_code, detail=result.get("detail"))
119
- self.request_id = response.json()["id"]
120
-
121
- def retrieve(self) -> dict:
122
- """
123
- Wait for the generation to finish and retrieve response.
124
- """
125
- if self.request_id is None:
126
- self.request()
127
- while self.result is None:
128
- response = requests.get(
129
- f"{API_ENDPOINT}/v1/get_result",
130
- headers={
131
- "accept": "application/json",
132
- "x-key": self.api_key,
133
- },
134
- params={
135
- "id": self.request_id,
136
- },
137
- )
138
- result = response.json()
139
- if "status" not in result:
140
- raise ApiException(status_code=response.status_code, detail=result.get("detail"))
141
- elif result["status"] == "Ready":
142
- self.result = result["result"]
143
- elif result["status"] == "Pending":
144
- time.sleep(0.5)
145
- else:
146
- raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
147
- return self.result
148
-
149
- @property
150
- def bytes(self) -> bytes:
151
- """
152
- Generated image as bytes.
153
- """
154
- if self._image_bytes is None:
155
- response = requests.get(self.url)
156
- if response.status_code == 200:
157
- self._image_bytes = response.content
158
- else:
159
- raise ApiException(status_code=response.status_code)
160
- return self._image_bytes
161
-
162
- @property
163
- def url(self) -> str:
164
- """
165
- Public url to retrieve the image from
166
- """
167
- if self._url is None:
168
- result = self.retrieve()
169
- self._url = result["sample"]
170
- return self._url
171
-
172
- @property
173
- def image(self) -> Image.Image:
174
- """
175
- Load the image as a PIL Image
176
- """
177
- return Image.open(io.BytesIO(self.bytes))
178
-
179
- def save(self, path: str):
180
- """
181
- Save the generated image to a local path
182
- """
183
- suffix = Path(self.url).suffix
184
- if not path.endswith(suffix):
185
- path = path + suffix
186
- Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
187
- with open(path, "wb") as file:
188
- file.write(self.bytes)
189
-
190
-
191
- if __name__ == "__main__":
192
- from fire import Fire
193
-
194
- Fire(ImageRequest)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/flux/math.py DELETED
@@ -1,29 +0,0 @@
1
- import torch
2
- from einops import rearrange
3
- from torch import Tensor
4
-
5
-
6
- def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
- q, k = apply_rope(q, k, pe)
8
-
9
- x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
10
- x = rearrange(x, "B H L D -> B L (H D)")
11
-
12
- return x
13
-
14
- def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
15
- assert dim % 2 == 0
16
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
17
- omega = 1.0 / (theta**scale)
18
- out = torch.einsum("...n,d->...nd", pos, omega)
19
- out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
20
- out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
21
- return out.float()
22
-
23
-
24
- def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
25
- xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
26
- xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
27
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
28
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
29
- return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/flux/model.py DELETED
@@ -1,118 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- import torch
4
- from torch import Tensor, nn
5
-
6
- from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
7
- MLPEmbedder, SingleStreamBlock,
8
- timestep_embedding)
9
-
10
-
11
- @dataclass
12
- class FluxParams:
13
- in_channels: int
14
- vec_in_dim: int
15
- context_in_dim: int
16
- hidden_size: int
17
- mlp_ratio: float
18
- num_heads: int
19
- depth: int
20
- depth_single_blocks: int
21
- axes_dim: list[int]
22
- theta: int
23
- qkv_bias: bool
24
- guidance_embed: bool
25
-
26
-
27
- class Flux(nn.Module):
28
- """
29
- Transformer model for flow matching on sequences.
30
- """
31
-
32
- def __init__(self, params: FluxParams):
33
- super().__init__()
34
-
35
- self.params = params
36
- self.in_channels = params.in_channels
37
- self.out_channels = self.in_channels
38
- if params.hidden_size % params.num_heads != 0:
39
- raise ValueError(
40
- f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
41
- )
42
- pe_dim = params.hidden_size // params.num_heads
43
- if sum(params.axes_dim) != pe_dim:
44
- raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
45
- self.hidden_size = params.hidden_size
46
- self.num_heads = params.num_heads
47
- self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
48
- self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
49
- self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
50
- self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
51
- self.guidance_in = (
52
- MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
53
- )
54
- self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
55
-
56
- self.double_blocks = nn.ModuleList(
57
- [
58
- DoubleStreamBlock(
59
- self.hidden_size,
60
- self.num_heads,
61
- mlp_ratio=params.mlp_ratio,
62
- qkv_bias=params.qkv_bias,
63
- )
64
- for _ in range(params.depth)
65
- ]
66
- )
67
-
68
- self.single_blocks = nn.ModuleList(
69
- [
70
- SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
71
- for _ in range(params.depth_single_blocks)
72
- ]
73
- )
74
-
75
- self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
76
-
77
- def forward(
78
- self,
79
- img: Tensor,
80
- img_ids: Tensor,
81
- txt: Tensor,
82
- txt_ids: Tensor,
83
- timesteps: Tensor,
84
- y: Tensor,
85
- guidance: Tensor | None = None,
86
- info = None,
87
- ) -> Tensor:
88
- if img.ndim != 3 or txt.ndim != 3:
89
- raise ValueError("Input img and txt tensors must have 3 dimensions.")
90
-
91
- # running on sequences img
92
- img = self.img_in(img)
93
- vec = self.time_in(timestep_embedding(timesteps, 256))
94
- if self.params.guidance_embed:
95
- if guidance is None:
96
- raise ValueError("Didn't get guidance strength for guidance distilled model.")
97
- vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
98
- vec = vec + self.vector_in(y)
99
- txt = self.txt_in(txt)
100
-
101
- ids = torch.cat((txt_ids, img_ids), dim=1)
102
- pe = self.pe_embedder(ids)
103
-
104
- for block in self.double_blocks:
105
- img, txt = block(img=img, txt=txt, vec=vec, pe=pe, info=info)
106
-
107
- cnt = 0
108
- img = torch.cat((txt, img), 1)
109
- info['type'] = 'single'
110
- for block in self.single_blocks:
111
- info['id'] = cnt
112
- img, info = block(img, vec=vec, pe=pe, info=info)
113
- cnt += 1
114
-
115
- img = img[:, txt.shape[1] :, ...]
116
-
117
- img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
118
- return img, info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/flux/modules/__pycache__/autoencoder.cpython-310.pyc DELETED
Binary file (9.06 kB)
 
src/flux/modules/__pycache__/conditioner.cpython-310.pyc DELETED
Binary file (1.49 kB)
 
src/flux/modules/__pycache__/layers.cpython-310.pyc DELETED
Binary file (10.3 kB)
 
src/flux/modules/autoencoder.py DELETED
@@ -1,313 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- import torch
4
- from einops import rearrange
5
- from torch import Tensor, nn
6
-
7
-
8
- @dataclass
9
- class AutoEncoderParams:
10
- resolution: int
11
- in_channels: int
12
- ch: int
13
- out_ch: int
14
- ch_mult: list[int]
15
- num_res_blocks: int
16
- z_channels: int
17
- scale_factor: float
18
- shift_factor: float
19
-
20
-
21
- def swish(x: Tensor) -> Tensor:
22
- return x * torch.sigmoid(x)
23
-
24
-
25
- class AttnBlock(nn.Module):
26
- def __init__(self, in_channels: int):
27
- super().__init__()
28
- self.in_channels = in_channels
29
-
30
- self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
-
32
- self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
- self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
- self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
- self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
-
37
- def attention(self, h_: Tensor) -> Tensor:
38
- h_ = self.norm(h_)
39
- q = self.q(h_)
40
- k = self.k(h_)
41
- v = self.v(h_)
42
-
43
- b, c, h, w = q.shape
44
- q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
- k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
- v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
- h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
-
49
- return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
-
51
- def forward(self, x: Tensor) -> Tensor:
52
- return x + self.proj_out(self.attention(x))
53
-
54
-
55
- class ResnetBlock(nn.Module):
56
- def __init__(self, in_channels: int, out_channels: int):
57
- super().__init__()
58
- self.in_channels = in_channels
59
- out_channels = in_channels if out_channels is None else out_channels
60
- self.out_channels = out_channels
61
-
62
- self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
- self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
- if self.in_channels != self.out_channels:
67
- self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
-
69
- def forward(self, x):
70
- h = x
71
- h = self.norm1(h)
72
- h = swish(h)
73
- h = self.conv1(h)
74
-
75
- h = self.norm2(h)
76
- h = swish(h)
77
- h = self.conv2(h)
78
-
79
- if self.in_channels != self.out_channels:
80
- x = self.nin_shortcut(x)
81
-
82
- return x + h
83
-
84
-
85
- class Downsample(nn.Module):
86
- def __init__(self, in_channels: int):
87
- super().__init__()
88
- # no asymmetric padding in torch conv, must do it ourselves
89
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
-
91
- def forward(self, x: Tensor):
92
- pad = (0, 1, 0, 1)
93
- x = nn.functional.pad(x, pad, mode="constant", value=0)
94
- x = self.conv(x)
95
- return x
96
-
97
-
98
- class Upsample(nn.Module):
99
- def __init__(self, in_channels: int):
100
- super().__init__()
101
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
-
103
- def forward(self, x: Tensor):
104
- x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
- x = self.conv(x)
106
- return x
107
-
108
-
109
- class Encoder(nn.Module):
110
- def __init__(
111
- self,
112
- resolution: int,
113
- in_channels: int,
114
- ch: int,
115
- ch_mult: list[int],
116
- num_res_blocks: int,
117
- z_channels: int,
118
- ):
119
- super().__init__()
120
- self.ch = ch
121
- self.num_resolutions = len(ch_mult)
122
- self.num_res_blocks = num_res_blocks
123
- self.resolution = resolution
124
- self.in_channels = in_channels
125
- # downsampling
126
- self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
-
128
- curr_res = resolution
129
- in_ch_mult = (1,) + tuple(ch_mult)
130
- self.in_ch_mult = in_ch_mult
131
- self.down = nn.ModuleList()
132
- block_in = self.ch
133
- for i_level in range(self.num_resolutions):
134
- block = nn.ModuleList()
135
- attn = nn.ModuleList()
136
- block_in = ch * in_ch_mult[i_level]
137
- block_out = ch * ch_mult[i_level]
138
- for _ in range(self.num_res_blocks):
139
- block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
- block_in = block_out
141
- down = nn.Module()
142
- down.block = block
143
- down.attn = attn
144
- if i_level != self.num_resolutions - 1:
145
- down.downsample = Downsample(block_in)
146
- curr_res = curr_res // 2
147
- self.down.append(down)
148
-
149
- # middle
150
- self.mid = nn.Module()
151
- self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
- self.mid.attn_1 = AttnBlock(block_in)
153
- self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
-
155
- # end
156
- self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
- self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
-
159
- def forward(self, x: Tensor) -> Tensor:
160
- # downsampling
161
- hs = [self.conv_in(x)]
162
- for i_level in range(self.num_resolutions):
163
- for i_block in range(self.num_res_blocks):
164
- h = self.down[i_level].block[i_block](hs[-1])
165
- if len(self.down[i_level].attn) > 0:
166
- h = self.down[i_level].attn[i_block](h)
167
- hs.append(h)
168
- if i_level != self.num_resolutions - 1:
169
- hs.append(self.down[i_level].downsample(hs[-1]))
170
-
171
- # middle
172
- h = hs[-1]
173
- h = self.mid.block_1(h)
174
- h = self.mid.attn_1(h)
175
- h = self.mid.block_2(h)
176
- # end
177
- h = self.norm_out(h)
178
- h = swish(h)
179
- h = self.conv_out(h)
180
- return h
181
-
182
-
183
- class Decoder(nn.Module):
184
- def __init__(
185
- self,
186
- ch: int,
187
- out_ch: int,
188
- ch_mult: list[int],
189
- num_res_blocks: int,
190
- in_channels: int,
191
- resolution: int,
192
- z_channels: int,
193
- ):
194
- super().__init__()
195
- self.ch = ch
196
- self.num_resolutions = len(ch_mult)
197
- self.num_res_blocks = num_res_blocks
198
- self.resolution = resolution
199
- self.in_channels = in_channels
200
- self.ffactor = 2 ** (self.num_resolutions - 1)
201
-
202
- # compute in_ch_mult, block_in and curr_res at lowest res
203
- block_in = ch * ch_mult[self.num_resolutions - 1]
204
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
- self.z_shape = (1, z_channels, curr_res, curr_res)
206
-
207
- # z to block_in
208
- self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
-
210
- # middle
211
- self.mid = nn.Module()
212
- self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
- self.mid.attn_1 = AttnBlock(block_in)
214
- self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
-
216
- # upsampling
217
- self.up = nn.ModuleList()
218
- for i_level in reversed(range(self.num_resolutions)):
219
- block = nn.ModuleList()
220
- attn = nn.ModuleList()
221
- block_out = ch * ch_mult[i_level]
222
- for _ in range(self.num_res_blocks + 1):
223
- block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
- block_in = block_out
225
- up = nn.Module()
226
- up.block = block
227
- up.attn = attn
228
- if i_level != 0:
229
- up.upsample = Upsample(block_in)
230
- curr_res = curr_res * 2
231
- self.up.insert(0, up) # prepend to get consistent order
232
-
233
- # end
234
- self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
- self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
-
237
- def forward(self, z: Tensor) -> Tensor:
238
- # z to block_in
239
- h = self.conv_in(z)
240
-
241
- # middle
242
- h = self.mid.block_1(h)
243
- h = self.mid.attn_1(h)
244
- h = self.mid.block_2(h)
245
-
246
- # upsampling
247
- for i_level in reversed(range(self.num_resolutions)):
248
- for i_block in range(self.num_res_blocks + 1):
249
- h = self.up[i_level].block[i_block](h)
250
- if len(self.up[i_level].attn) > 0:
251
- h = self.up[i_level].attn[i_block](h)
252
- if i_level != 0:
253
- h = self.up[i_level].upsample(h)
254
-
255
- # end
256
- h = self.norm_out(h)
257
- h = swish(h)
258
- h = self.conv_out(h)
259
- return h
260
-
261
-
262
- class DiagonalGaussian(nn.Module):
263
- def __init__(self, sample: bool = True, chunk_dim: int = 1):
264
- super().__init__()
265
- self.sample = sample
266
- self.chunk_dim = chunk_dim
267
-
268
- def forward(self, z: Tensor) -> Tensor:
269
- mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270
- # import pdb;pdb.set_trace()
271
- if self.sample:
272
- std = torch.exp(0.5 * logvar)
273
- return mean #+ std * torch.randn_like(mean)
274
- else:
275
- return mean
276
-
277
-
278
- class AutoEncoder(nn.Module):
279
- def __init__(self, params: AutoEncoderParams):
280
- super().__init__()
281
- self.encoder = Encoder(
282
- resolution=params.resolution,
283
- in_channels=params.in_channels,
284
- ch=params.ch,
285
- ch_mult=params.ch_mult,
286
- num_res_blocks=params.num_res_blocks,
287
- z_channels=params.z_channels,
288
- )
289
- self.decoder = Decoder(
290
- resolution=params.resolution,
291
- in_channels=params.in_channels,
292
- ch=params.ch,
293
- out_ch=params.out_ch,
294
- ch_mult=params.ch_mult,
295
- num_res_blocks=params.num_res_blocks,
296
- z_channels=params.z_channels,
297
- )
298
- self.reg = DiagonalGaussian()
299
-
300
- self.scale_factor = params.scale_factor
301
- self.shift_factor = params.shift_factor
302
-
303
- def encode(self, x: Tensor) -> Tensor:
304
- z = self.reg(self.encoder(x))
305
- z = self.scale_factor * (z - self.shift_factor)
306
- return z
307
-
308
- def decode(self, z: Tensor) -> Tensor:
309
- z = z / self.scale_factor + self.shift_factor
310
- return self.decoder(z)
311
-
312
- def forward(self, x: Tensor) -> Tensor:
313
- return self.decode(self.encode(x))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/flux/modules/conditioner.py DELETED
@@ -1,38 +0,0 @@
1
- from torch import Tensor, nn
2
- from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
3
- T5Tokenizer)
4
-
5
-
6
- class HFEmbedder(nn.Module):
7
- def __init__(self, version: str, max_length: int, is_clip, **hf_kwargs):
8
- super().__init__()
9
- self.is_clip = is_clip
10
- self.max_length = max_length
11
- self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
12
-
13
- if self.is_clip:
14
- self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
15
- self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
16
- else:
17
- self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
18
- self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
19
-
20
- self.hf_module = self.hf_module.eval().requires_grad_(False)
21
-
22
- def forward(self, text: list[str]) -> Tensor:
23
- batch_encoding = self.tokenizer(
24
- text,
25
- truncation=True,
26
- max_length=self.max_length,
27
- return_length=False,
28
- return_overflowing_tokens=False,
29
- padding="max_length",
30
- return_tensors="pt",
31
- )
32
-
33
- outputs = self.hf_module(
34
- input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
35
- attention_mask=None,
36
- output_hidden_states=False,
37
- )
38
- return outputs[self.output_key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/flux/modules/layers.py DELETED
@@ -1,280 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
-
4
- import torch
5
- from einops import rearrange
6
- from torch import Tensor, nn
7
-
8
- from flux.math import attention, rope
9
-
10
- import os
11
-
12
- class EmbedND(nn.Module):
13
- def __init__(self, dim: int, theta: int, axes_dim: list[int]):
14
- super().__init__()
15
- self.dim = dim
16
- self.theta = theta
17
- self.axes_dim = axes_dim
18
-
19
- def forward(self, ids: Tensor) -> Tensor:
20
- n_axes = ids.shape[-1]
21
- emb = torch.cat(
22
- [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
23
- dim=-3,
24
- )
25
-
26
- return emb.unsqueeze(1)
27
-
28
-
29
- def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
30
- """
31
- Create sinusoidal timestep embeddings.
32
- :param t: a 1-D Tensor of N indices, one per batch element.
33
- These may be fractional.
34
- :param dim: the dimension of the output.
35
- :param max_period: controls the minimum frequency of the embeddings.
36
- :return: an (N, D) Tensor of positional embeddings.
37
- """
38
- t = time_factor * t
39
- half = dim // 2
40
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
41
- t.device
42
- )
43
-
44
- args = t[:, None].float() * freqs[None]
45
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
46
- if dim % 2:
47
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
48
- if torch.is_floating_point(t):
49
- embedding = embedding.to(t)
50
- return embedding
51
-
52
-
53
- class MLPEmbedder(nn.Module):
54
- def __init__(self, in_dim: int, hidden_dim: int):
55
- super().__init__()
56
- self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
57
- self.silu = nn.SiLU()
58
- self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
59
-
60
- def forward(self, x: Tensor) -> Tensor:
61
- return self.out_layer(self.silu(self.in_layer(x)))
62
-
63
-
64
- class RMSNorm(torch.nn.Module):
65
- def __init__(self, dim: int):
66
- super().__init__()
67
- self.scale = nn.Parameter(torch.ones(dim))
68
-
69
- def forward(self, x: Tensor):
70
- x_dtype = x.dtype
71
- x = x.float()
72
- rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
73
- return (x * rrms).to(dtype=x_dtype) * self.scale
74
-
75
-
76
- class QKNorm(torch.nn.Module):
77
- def __init__(self, dim: int):
78
- super().__init__()
79
- self.query_norm = RMSNorm(dim)
80
- self.key_norm = RMSNorm(dim)
81
-
82
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
83
- q = self.query_norm(q)
84
- k = self.key_norm(k)
85
- return q.to(v), k.to(v)
86
-
87
-
88
- class SelfAttention(nn.Module):
89
- def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
90
- super().__init__()
91
- self.num_heads = num_heads
92
- head_dim = dim // num_heads
93
-
94
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
95
- self.norm = QKNorm(head_dim)
96
- self.proj = nn.Linear(dim, dim)
97
-
98
- def forward(self, x: Tensor, pe: Tensor) -> Tensor:
99
- qkv = self.qkv(x)
100
- q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
101
- q, k = self.norm(q, k, v)
102
- x = attention(q, k, v, pe=pe)
103
- x = self.proj(x)
104
- return x
105
-
106
-
107
- @dataclass
108
- class ModulationOut:
109
- shift: Tensor
110
- scale: Tensor
111
- gate: Tensor
112
-
113
-
114
- class Modulation(nn.Module):
115
- def __init__(self, dim: int, double: bool):
116
- super().__init__()
117
- self.is_double = double
118
- self.multiplier = 6 if double else 3
119
- self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
120
-
121
- def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
122
- out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
123
-
124
- return (
125
- ModulationOut(*out[:3]),
126
- ModulationOut(*out[3:]) if self.is_double else None,
127
- )
128
-
129
-
130
- class DoubleStreamBlock(nn.Module):
131
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
132
- super().__init__()
133
-
134
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
135
- self.num_heads = num_heads
136
- self.hidden_size = hidden_size
137
- self.img_mod = Modulation(hidden_size, double=True)
138
- self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
139
- self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
140
-
141
- self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
142
- self.img_mlp = nn.Sequential(
143
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
144
- nn.GELU(approximate="tanh"),
145
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
146
- )
147
-
148
- self.txt_mod = Modulation(hidden_size, double=True)
149
- self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
150
- self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
151
-
152
- self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
153
- self.txt_mlp = nn.Sequential(
154
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
155
- nn.GELU(approximate="tanh"),
156
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
157
- )
158
-
159
- def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, info) -> tuple[Tensor, Tensor]:
160
- img_mod1, img_mod2 = self.img_mod(vec)
161
- txt_mod1, txt_mod2 = self.txt_mod(vec)
162
-
163
- # prepare image for attention
164
- img_modulated = self.img_norm1(img)
165
- img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
166
- img_qkv = self.img_attn.qkv(img_modulated)
167
- img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
168
-
169
- img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
170
-
171
- # if info['inject']:
172
- # if info['inverse']:
173
- # print("!save! ",info['feature_path'] + str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'])
174
- # torch.save(img_q, info['feature_path'] + str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'Q' + '.pth')
175
- # if not info['inverse']:
176
- # print("!load! ", info['feature_path'] + str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'])
177
- # img_q = torch.load(info['feature_path'] + str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'Q' + '.pth', weights_only=True)
178
-
179
- # prepare txt for attention
180
- txt_modulated = self.txt_norm1(txt)
181
- txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
182
- txt_qkv = self.txt_attn.qkv(txt_modulated)
183
- txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
184
- txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
185
-
186
- # run actual attention
187
- q = torch.cat((txt_q, img_q), dim=2) #[8, 24, 512, 128] + [8, 24, 900, 128] -> [8, 24, 1412, 128]
188
- k = torch.cat((txt_k, img_k), dim=2)
189
- v = torch.cat((txt_v, img_v), dim=2)
190
- # import pdb;pdb.set_trace()
191
- attn = attention(q, k, v, pe=pe)
192
-
193
- txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
194
-
195
- # calculate the img bloks
196
- img = img + img_mod1.gate * self.img_attn.proj(img_attn)
197
- img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
198
-
199
- # calculate the txt bloks
200
- txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
201
- txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
202
- return img, txt
203
-
204
-
205
- class SingleStreamBlock(nn.Module):
206
- """
207
- A DiT block with parallel linear layers as described in
208
- https://arxiv.org/abs/2302.05442 and adapted modulation interface.
209
- """
210
-
211
- def __init__(
212
- self,
213
- hidden_size: int,
214
- num_heads: int,
215
- mlp_ratio: float = 4.0,
216
- qk_scale: float | None = None,
217
- ):
218
- super().__init__()
219
- self.hidden_dim = hidden_size
220
- self.num_heads = num_heads
221
- head_dim = hidden_size // num_heads
222
- self.scale = qk_scale or head_dim**-0.5
223
-
224
- self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
225
- # qkv and mlp_in
226
- self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
227
- # proj and mlp_out
228
- self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
229
-
230
- self.norm = QKNorm(head_dim)
231
-
232
- self.hidden_size = hidden_size
233
- self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
234
-
235
- self.mlp_act = nn.GELU(approximate="tanh")
236
- self.modulation = Modulation(hidden_size, double=False)
237
-
238
- def forward(self, x: Tensor, vec: Tensor, pe: Tensor, info) -> Tensor:
239
- mod, _ = self.modulation(vec)
240
- x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
241
- qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
242
-
243
- q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
244
- q, k = self.norm(q, k, v)
245
-
246
- # Note: If the memory of your device is not enough, you may consider uncomment the following code.
247
- # if info['inject'] and info['id'] > 19:
248
- # store_path = os.path.join(info['feature_path'], str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'V' + '.pth')
249
- # if info['inverse']:
250
- # torch.save(v, store_path)
251
- # if not info['inverse']:
252
- # v = torch.load(store_path, weights_only=True)
253
-
254
- # Save the features in the memory
255
- if info['inject'] and info['id'] > 19:
256
- feature_name = str(info['t']) + '_' + str(info['second_order']) + '_' + str(info['id']) + '_' + info['type'] + '_' + 'V'
257
- if info['inverse']:
258
- info['feature'][feature_name] = v.cpu()
259
- else:
260
- v = info['feature'][feature_name].cuda()
261
-
262
- # compute attention
263
- attn = attention(q, k, v, pe=pe)
264
- # compute activation in mlp stream, cat again and run second linear layer
265
- output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
266
- return x + mod.gate * output, info
267
-
268
-
269
- class LastLayer(nn.Module):
270
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
271
- super().__init__()
272
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
273
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
274
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
275
-
276
- def forward(self, x: Tensor, vec: Tensor) -> Tensor:
277
- shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
278
- x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
279
- x = self.linear(x)
280
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/flux/sampling.py DELETED
@@ -1,147 +0,0 @@
1
- import math
2
- from typing import Callable
3
-
4
- import torch
5
- from einops import rearrange, repeat
6
- from torch import Tensor
7
-
8
- from .model import Flux
9
- from .modules.conditioner import HFEmbedder
10
-
11
-
12
- def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
13
- bs, c, h, w = img.shape
14
- if bs == 1 and not isinstance(prompt, str):
15
- bs = len(prompt)
16
-
17
- img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
18
- if img.shape[0] == 1 and bs > 1:
19
- img = repeat(img, "1 ... -> bs ...", bs=bs)
20
-
21
- img_ids = torch.zeros(h // 2, w // 2, 3)
22
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
23
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
24
- img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
25
-
26
- if isinstance(prompt, str):
27
- prompt = [prompt]
28
- txt = t5(prompt)
29
- if txt.shape[0] == 1 and bs > 1:
30
- txt = repeat(txt, "1 ... -> bs ...", bs=bs)
31
- txt_ids = torch.zeros(bs, txt.shape[1], 3)
32
-
33
- vec = clip(prompt)
34
- if vec.shape[0] == 1 and bs > 1:
35
- vec = repeat(vec, "1 ... -> bs ...", bs=bs)
36
-
37
- return {
38
- "img": img,
39
- "img_ids": img_ids.to(img.device),
40
- "txt": txt.to(img.device),
41
- "txt_ids": txt_ids.to(img.device),
42
- "vec": vec.to(img.device),
43
- }
44
-
45
-
46
- def time_shift(mu: float, sigma: float, t: Tensor):
47
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
48
-
49
-
50
- def get_lin_function(
51
- x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
52
- ) -> Callable[[float], float]:
53
- m = (y2 - y1) / (x2 - x1)
54
- b = y1 - m * x1
55
- return lambda x: m * x + b
56
-
57
-
58
- def get_schedule(
59
- num_steps: int,
60
- image_seq_len: int,
61
- base_shift: float = 0.5,
62
- max_shift: float = 1.15,
63
- shift: bool = True,
64
- ) -> list[float]:
65
- # extra step for zero
66
- timesteps = torch.linspace(1, 0, num_steps + 1)
67
-
68
- # shifting the schedule to favor high timesteps for higher signal images
69
- if shift:
70
- # estimate mu based on linear estimation between two points
71
- mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
72
- timesteps = time_shift(mu, 1.0, timesteps)
73
-
74
- return timesteps.tolist()
75
-
76
-
77
- def denoise(
78
- model: Flux,
79
- # model input
80
- img: Tensor,
81
- img_ids: Tensor,
82
- txt: Tensor,
83
- txt_ids: Tensor,
84
- vec: Tensor,
85
- # sampling parameters
86
- timesteps: list[float],
87
- inverse,
88
- info,
89
- guidance: float = 4.0
90
- ):
91
- # this is ignored for schnell
92
- inject_list = [True] * info['inject_step'] + [False] * (len(timesteps[:-1]) - info['inject_step'])
93
-
94
- if inverse:
95
- timesteps = timesteps[::-1]
96
- inject_list = inject_list[::-1]
97
- guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
98
-
99
- step_list = []
100
- for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
101
- t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
102
- info['t'] = t_prev if inverse else t_curr
103
- info['inverse'] = inverse
104
- info['second_order'] = False
105
- info['inject'] = inject_list[i]
106
-
107
- pred, info = model(
108
- img=img,
109
- img_ids=img_ids,
110
- txt=txt,
111
- txt_ids=txt_ids,
112
- y=vec,
113
- timesteps=t_vec,
114
- guidance=guidance_vec,
115
- info=info
116
- )
117
-
118
- img_mid = img + (t_prev - t_curr) / 2 * pred
119
-
120
- t_vec_mid = torch.full((img.shape[0],), (t_curr + (t_prev - t_curr) / 2), dtype=img.dtype, device=img.device)
121
- info['second_order'] = True
122
- pred_mid, info = model(
123
- img=img_mid,
124
- img_ids=img_ids,
125
- txt=txt,
126
- txt_ids=txt_ids,
127
- y=vec,
128
- timesteps=t_vec_mid,
129
- guidance=guidance_vec,
130
- info=info
131
- )
132
-
133
- first_order = (pred_mid - pred) / ((t_prev - t_curr) / 2)
134
- img = img + (t_prev - t_curr) * pred + 0.5 * (t_prev - t_curr) ** 2 * first_order
135
-
136
- return img, info
137
-
138
-
139
- def unpack(x: Tensor, height: int, width: int) -> Tensor:
140
- return rearrange(
141
- x,
142
- "b (h w) (c ph pw) -> b c (h ph) (w pw)",
143
- h=math.ceil(height / 16),
144
- w=math.ceil(width / 16),
145
- ph=2,
146
- pw=2,
147
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/flux/util.py DELETED
@@ -1,201 +0,0 @@
1
- import os
2
- from dataclasses import dataclass
3
-
4
- import torch
5
- from einops import rearrange
6
- from huggingface_hub import hf_hub_download
7
- from imwatermark import WatermarkEncoder
8
- from safetensors.torch import load_file as load_sft
9
-
10
- from flux.model import Flux, FluxParams
11
- from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
12
- from flux.modules.conditioner import HFEmbedder
13
-
14
-
15
- @dataclass
16
- class ModelSpec:
17
- params: FluxParams
18
- ae_params: AutoEncoderParams
19
- ckpt_path: str | None
20
- ae_path: str | None
21
- repo_id: str | None
22
- repo_flow: str | None
23
- repo_ae: str | None
24
-
25
- configs = {
26
- "flux-dev": ModelSpec(
27
- repo_id="black-forest-labs/FLUX.1-dev",
28
- repo_flow="flux1-dev.safetensors",
29
- repo_ae="ae.safetensors",
30
- ckpt_path=os.getenv("FLUX_DEV"),
31
- params=FluxParams(
32
- in_channels=64,
33
- vec_in_dim=768,
34
- context_in_dim=4096,
35
- hidden_size=3072,
36
- mlp_ratio=4.0,
37
- num_heads=24,
38
- depth=19,
39
- depth_single_blocks=38,
40
- axes_dim=[16, 56, 56],
41
- theta=10_000,
42
- qkv_bias=True,
43
- guidance_embed=True,
44
- ),
45
- ae_path=os.getenv("AE"),
46
- ae_params=AutoEncoderParams(
47
- resolution=256,
48
- in_channels=3,
49
- ch=128,
50
- out_ch=3,
51
- ch_mult=[1, 2, 4, 4],
52
- num_res_blocks=2,
53
- z_channels=16,
54
- scale_factor=0.3611,
55
- shift_factor=0.1159,
56
- ),
57
- ),
58
- "flux-schnell": ModelSpec(
59
- repo_id="black-forest-labs/FLUX.1-schnell",
60
- repo_flow="flux1-schnell.safetensors",
61
- repo_ae="ae.safetensors",
62
- ckpt_path=os.getenv("FLUX_SCHNELL"),
63
- params=FluxParams(
64
- in_channels=64,
65
- vec_in_dim=768,
66
- context_in_dim=4096,
67
- hidden_size=3072,
68
- mlp_ratio=4.0,
69
- num_heads=24,
70
- depth=19,
71
- depth_single_blocks=38,
72
- axes_dim=[16, 56, 56],
73
- theta=10_000,
74
- qkv_bias=True,
75
- guidance_embed=False,
76
- ),
77
- ae_path=os.getenv("AE"),
78
- ae_params=AutoEncoderParams(
79
- resolution=256,
80
- in_channels=3,
81
- ch=128,
82
- out_ch=3,
83
- ch_mult=[1, 2, 4, 4],
84
- num_res_blocks=2,
85
- z_channels=16,
86
- scale_factor=0.3611,
87
- shift_factor=0.1159,
88
- ),
89
- ),
90
- }
91
-
92
-
93
- def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
94
- if len(missing) > 0 and len(unexpected) > 0:
95
- print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
96
- print("\n" + "-" * 79 + "\n")
97
- print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
98
- elif len(missing) > 0:
99
- print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
100
- elif len(unexpected) > 0:
101
- print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
102
-
103
-
104
- def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
105
- # Loading Flux
106
- print("Init model")
107
-
108
- ckpt_path = configs[name].ckpt_path
109
- if (
110
- ckpt_path is None
111
- and configs[name].repo_id is not None
112
- and configs[name].repo_flow is not None
113
- and hf_download
114
- ):
115
- ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
116
-
117
- with torch.device("meta" if ckpt_path is not None else device):
118
- model = Flux(configs[name].params).to(torch.bfloat16)
119
-
120
- if ckpt_path is not None:
121
- print("Loading checkpoint")
122
- # load_sft doesn't support torch.device
123
- sd = load_sft(ckpt_path, device=str(device))
124
- missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
125
- print_load_warning(missing, unexpected)
126
- return model
127
-
128
-
129
- def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
130
- # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
131
- return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, is_clip=False, torch_dtype=torch.bfloat16).to(device)
132
-
133
-
134
- def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
135
- return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True, torch_dtype=torch.bfloat16).to(device)
136
-
137
-
138
- def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
139
- ckpt_path = configs[name].ae_path
140
- if (
141
- ckpt_path is None
142
- and configs[name].repo_id is not None
143
- and configs[name].repo_ae is not None
144
- and hf_download
145
- ):
146
- ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
147
-
148
- # Loading the autoencoder
149
- print("Init AE")
150
- with torch.device("meta" if ckpt_path is not None else device):
151
- ae = AutoEncoder(configs[name].ae_params)
152
-
153
- if ckpt_path is not None:
154
- sd = load_sft(ckpt_path, device=str(device))
155
- missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
156
- print_load_warning(missing, unexpected)
157
- return ae
158
-
159
-
160
- class WatermarkEmbedder:
161
- def __init__(self, watermark):
162
- self.watermark = watermark
163
- self.num_bits = len(WATERMARK_BITS)
164
- self.encoder = WatermarkEncoder()
165
- self.encoder.set_watermark("bits", self.watermark)
166
-
167
- def __call__(self, image: torch.Tensor) -> torch.Tensor:
168
- """
169
- Adds a predefined watermark to the input image
170
-
171
- Args:
172
- image: ([N,] B, RGB, H, W) in range [-1, 1]
173
-
174
- Returns:
175
- same as input but watermarked
176
- """
177
- image = 0.5 * image + 0.5
178
- squeeze = len(image.shape) == 4
179
- if squeeze:
180
- image = image[None, ...]
181
- n = image.shape[0]
182
- image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
183
- # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
184
- # watermarking libary expects input as cv2 BGR format
185
- for k in range(image_np.shape[0]):
186
- image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
187
- image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
188
- image.device
189
- )
190
- image = torch.clamp(image / 255, min=0.0, max=1.0)
191
- if squeeze:
192
- image = image[0]
193
- image = 2 * image - 1
194
- return image
195
-
196
-
197
- # A fixed 48-bit message that was chosen at random
198
- WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
199
- # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
200
- WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
201
- embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/gradio_demo.py DELETED
@@ -1,243 +0,0 @@
1
- import os
2
- import re
3
- import time
4
- from io import BytesIO
5
- import uuid
6
- from dataclasses import dataclass
7
- from glob import iglob
8
- import argparse
9
- from einops import rearrange
10
- from fire import Fire
11
- from PIL import ExifTags, Image
12
-
13
-
14
- import torch
15
- import torch.nn.functional as F
16
- import gradio as gr
17
- import numpy as np
18
- from transformers import pipeline
19
-
20
- from flux.sampling import denoise, get_schedule, prepare, unpack
21
- from flux.util import (configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5)
22
-
23
- @dataclass
24
- class SamplingOptions:
25
- source_prompt: str
26
- target_prompt: str
27
- # prompt: str
28
- width: int
29
- height: int
30
- num_steps: int
31
- guidance: float
32
- seed: int | None
33
-
34
- @torch.inference_mode()
35
- def encode(init_image, torch_device, ae):
36
- init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
37
- init_image = init_image.unsqueeze(0)
38
- init_image = init_image.to(torch_device)
39
- with torch.no_grad():
40
- init_image = ae.encode(init_image.to()).to(torch.bfloat16)
41
- return init_image
42
-
43
-
44
- class FluxEditor:
45
- def __init__(self, args):
46
- self.args = args
47
- self.device = torch.device(args.device)
48
- self.offload = args.offload
49
- self.name = args.name
50
- self.is_schnell = args.name == "flux-schnell"
51
-
52
- self.feature_path = 'feature'
53
- self.output_dir = 'result'
54
- self.add_sampling_metadata = True
55
-
56
- if self.name not in configs:
57
- available = ", ".join(configs.keys())
58
- raise ValueError(f"Got unknown model name: {name}, chose from {available}")
59
-
60
- # init all components
61
- self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
62
- self.clip = load_clip(self.device)
63
- self.model = load_flow_model(self.name, device="cpu" if self.offload else self.device)
64
- self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)
65
- self.t5.eval()
66
- self.clip.eval()
67
- self.ae.eval()
68
- self.model.eval()
69
-
70
- if self.offload:
71
- self.model.cpu()
72
- torch.cuda.empty_cache()
73
- self.ae.encoder.to(self.device)
74
-
75
- @torch.inference_mode()
76
- def edit(self, init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
77
- torch.cuda.empty_cache()
78
- seed = None
79
- # if seed == -1:
80
- # seed = None
81
-
82
- shape = init_image.shape
83
-
84
- new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
85
- new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
86
-
87
- init_image = init_image[:new_h, :new_w, :]
88
-
89
- width, height = init_image.shape[0], init_image.shape[1]
90
- init_image = encode(init_image, self.device, self.ae)
91
-
92
- print(init_image.shape)
93
-
94
- rng = torch.Generator(device="cpu")
95
- opts = SamplingOptions(
96
- source_prompt=source_prompt,
97
- target_prompt=target_prompt,
98
- width=width,
99
- height=height,
100
- num_steps=num_steps,
101
- guidance=guidance,
102
- seed=seed,
103
- )
104
- if opts.seed is None:
105
- opts.seed = torch.Generator(device="cpu").seed()
106
-
107
- print(f"Generating with seed {opts.seed}:\n{opts.source_prompt}")
108
- t0 = time.perf_counter()
109
-
110
- opts.seed = None
111
- if self.offload:
112
- self.ae = self.ae.cpu()
113
- torch.cuda.empty_cache()
114
- self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
115
-
116
- #############inverse#######################
117
- info = {}
118
- info['feature'] = {}
119
- info['inject_step'] = inject_step
120
-
121
- if not os.path.exists(self.feature_path):
122
- os.mkdir(self.feature_path)
123
-
124
- with torch.no_grad():
125
- inp = prepare(self.t5, self.clip, init_image, prompt=opts.source_prompt)
126
- inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
127
- timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
128
-
129
- # offload TEs to CPU, load model to gpu
130
- if self.offload:
131
- self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
132
- torch.cuda.empty_cache()
133
- self.model = self.model.to(self.device)
134
-
135
- # inversion initial noise
136
- with torch.no_grad():
137
- z, info = denoise(self.model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
138
-
139
- inp_target["img"] = z
140
-
141
- timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(self.name != "flux-schnell"))
142
-
143
- # denoise initial noise
144
- x, _ = denoise(self.model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
145
-
146
- # offload model, load autoencoder to gpu
147
- if self.offload:
148
- self.model.cpu()
149
- torch.cuda.empty_cache()
150
- self.ae.decoder.to(x.device)
151
-
152
- # decode latents to pixel space
153
- x = unpack(x.float(), opts.width, opts.height)
154
-
155
- output_name = os.path.join(self.output_dir, "img_{idx}.jpg")
156
- if not os.path.exists(self.output_dir):
157
- os.makedirs(self.output_dir)
158
- idx = 0
159
- else:
160
- fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
161
- if len(fns) > 0:
162
- idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
163
- else:
164
- idx = 0
165
-
166
- with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
167
- x = self.ae.decode(x)
168
-
169
- if torch.cuda.is_available():
170
- torch.cuda.synchronize()
171
- t1 = time.perf_counter()
172
-
173
- fn = output_name.format(idx=idx)
174
- print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
175
- # bring into PIL format and save
176
- x = x.clamp(-1, 1)
177
- x = embed_watermark(x.float())
178
- x = rearrange(x[0], "c h w -> h w c")
179
-
180
- img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
181
- exif_data = Image.Exif()
182
- exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
183
- exif_data[ExifTags.Base.Make] = "Black Forest Labs"
184
- exif_data[ExifTags.Base.Model] = self.name
185
- if self.add_sampling_metadata:
186
- exif_data[ExifTags.Base.ImageDescription] = source_prompt
187
- img.save(fn, exif=exif_data, quality=95, subsampling=0)
188
-
189
-
190
- print("End Edit")
191
- return img
192
-
193
-
194
-
195
- def create_demo(model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False):
196
- editor = FluxEditor(args)
197
- is_schnell = model_name == "flux-schnell"
198
-
199
- with gr.Blocks() as demo:
200
- gr.Markdown(f"# RF-Edit Demo (FLUX for image editing)")
201
-
202
- with gr.Row():
203
- with gr.Column():
204
- source_prompt = gr.Textbox(label="Source Prompt", value="")
205
- target_prompt = gr.Textbox(label="Target Prompt", value="")
206
- init_image = gr.Image(label="Input Image", visible=True)
207
-
208
-
209
- generate_btn = gr.Button("Generate")
210
-
211
- with gr.Column():
212
- with gr.Accordion("Advanced Options", open=True):
213
- num_steps = gr.Slider(1, 30, 25, step=1, label="Number of steps")
214
- inject_step = gr.Slider(1, 15, 5, step=1, label="Number of inject steps")
215
- guidance = gr.Slider(1.0, 10.0, 2, step=0.1, label="Guidance", interactive=not is_schnell)
216
- # seed = gr.Textbox(0, label="Seed (-1 for random)", visible=False)
217
- # add_sampling_metadata = gr.Checkbox(label="Add sampling parameters to metadata?", value=False)
218
-
219
- output_image = gr.Image(label="Generated Image")
220
-
221
- generate_btn.click(
222
- fn=editor.edit,
223
- inputs=[init_image, source_prompt, target_prompt, num_steps, inject_step, guidance],
224
- outputs=[output_image]
225
- )
226
-
227
-
228
- return demo
229
-
230
-
231
- if __name__ == "__main__":
232
- import argparse
233
- parser = argparse.ArgumentParser(description="Flux")
234
- parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name")
235
- parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
236
- parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
237
- parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
238
-
239
- parser.add_argument("--port", type=int, default=41035)
240
- args = parser.parse_args()
241
-
242
- demo = create_demo(args.name, args.device, args.offload)
243
- demo.launch(server_name='0.0.0.0', share=args.share, server_port=args.port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/run_art_batman.sh DELETED
@@ -1,8 +0,0 @@
1
- python edit.py --source_prompt "" \
2
- --target_prompt "a vivid depiction of the Batman, featuring rich, dynamic colors, and a blend of realistic and abstract elements with dynamic splatter art." \
3
- --guidance 2 \
4
- --source_img_dir '/examples/source/art.jpg' \
5
- --num_steps 25 \
6
- --inject 5 \
7
- --name 'flux-dev' \
8
- --output_dir 'examples/edit-result/art/'
 
 
 
 
 
 
 
 
 
src/run_art_mari.sh DELETED
@@ -1,8 +0,0 @@
1
- python edit.py --source_prompt "" \
2
- --target_prompt "a vivid depiction of the Marilyn Monroe, featuring rich, dynamic colors, and a blend of realistic and abstract elements with dynamic splatter art." \
3
- --guidance 2 \
4
- --source_img_dir '/examples/source/art.jpg' \
5
- --num_steps 25 \
6
- --inject 3 \
7
- --name 'flux-dev' \
8
- --output_dir 'examples/edit-result/art/'
 
 
 
 
 
 
 
 
 
src/run_boy.sh DELETED
@@ -1,10 +0,0 @@
1
- python edit.py --source_prompt "A young boy is playing with a toy airplane on the grassy front lawn of a suburban house, with a blue sky and fluffy clouds above." \
2
- --target_prompt "A young boy is playing with a toy airplane on the grassy front lawn of a suburban house, with a small brown dog playing beside him, and a blue sky with fluffy clouds above." \
3
- --guidance 2 \
4
- --source_img_dir 'examples/source/boy.jpg' \
5
- --num_steps 15 --offload \
6
- --inject 2 \
7
- --name 'flux-dev' \
8
- --output_dir 'examples/edit-result/dog'
9
-
10
-
 
 
 
 
 
 
 
 
 
 
 
src/run_cartoon_ein.sh DELETED
@@ -1,8 +0,0 @@
1
- python edit.py --source_prompt "" \
2
- --target_prompt "a cartoon style Albert Einstein raising his left hand " \
3
- --guidance 2 \
4
- --source_img_dir 'examples/source/cartoon.jpg' \
5
- --num_steps 25 \
6
- --inject 2 \
7
- --name 'flux-dev' \
8
- --output_dir 'examples/edit-result/cartoon/'
 
 
 
 
 
 
 
 
 
src/run_cartoon_herry.sh DELETED
@@ -1,8 +0,0 @@
1
- python edit.py --source_prompt "" \
2
- --target_prompt "a cartoon style Herry Potter raising his left hand " \
3
- --guidance 2 \
4
- --source_img_dir 'examples/source/cartoon.jpg' \
5
- --num_steps 25 \
6
- --inject 2 \
7
- --name 'flux-dev' \
8
- --output_dir 'examples/edit-result/cartoon/'
 
 
 
 
 
 
 
 
 
src/run_hiking.sh DELETED
@@ -1,9 +0,0 @@
1
- python edit.py --source_prompt "A woman hiking on a trail with mountains in the distance, carrying a backpack." \
2
- --target_prompt "A woman hiking on a trail with mountains in the distance, carrying a backpack and holding a hiking stick." \
3
- --guidance 2 \
4
- --source_img_dir 'examples/source/hiking.jpg' \
5
- --num_steps 15 \
6
- --inject 2 --offload \
7
- --name 'flux-dev' \
8
- --output_dir 'examples/edit-result/hiking/'
9
-
 
 
 
 
 
 
 
 
 
 
src/run_horse.sh DELETED
@@ -1,9 +0,0 @@
1
-
2
- python edit.py --source_prompt "A young boy is riding a brown horse in a countryside field, with a large tree in the background." \
3
- --target_prompt "A young boy is riding a camel in a countryside field, with a large tree in the background." \
4
- --guidance 2 \
5
- --source_img_dir 'examples/source/horse.jpg' \
6
- --num_steps 15 \
7
- --inject 3 --offload \
8
- --name 'flux-dev' \
9
- --output_dir 'examples/edit-result/horse/'
 
 
 
 
 
 
 
 
 
 
src/run_nobel_biden.sh DELETED
@@ -1,8 +0,0 @@
1
- python edit.py --source_prompt "" \
2
- --target_prompt "A minimalistic line-drawing portrait of Joe Biden with black lines and light brown shadow" \
3
- --guidance 2.5 \
4
- --source_img_dir 'examples/source/nobel.jpg' \
5
- --num_steps 25 \
6
- --inject 2 \
7
- --name 'flux-dev' \
8
- --output_dir 'examples/edit-result/nobel/'
 
 
 
 
 
 
 
 
 
src/run_nobel_trump.sh DELETED
@@ -1,8 +0,0 @@
1
- python edit.py --source_prompt "" \
2
- --target_prompt "A minimalistic line-drawing portrait of Donald Trump with black lines and brown shadow" \
3
- --guidance 2.5 \
4
- --source_img_dir 'examples/source/nobel.jpg' \
5
- --num_steps 25 \
6
- --inject 3 \
7
- --name 'flux-dev' \
8
- --output_dir 'examples/edit-result/nobel/'