Update flux/util.py
Browse files- flux/util.py +0 -45
flux/util.py
CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
|
|
4 |
import torch
|
5 |
from einops import rearrange
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
-
from imwatermark import WatermarkEncoder
|
8 |
from safetensors.torch import load_file as load_sft
|
9 |
|
10 |
from flux.model import Flux, FluxParams
|
@@ -155,47 +154,3 @@ def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEn
|
|
155 |
missing, unexpected = ae.load_state_dict(sd, strict=False)
|
156 |
print_load_warning(missing, unexpected)
|
157 |
return ae
|
158 |
-
|
159 |
-
|
160 |
-
class WatermarkEmbedder:
|
161 |
-
def __init__(self, watermark):
|
162 |
-
self.watermark = watermark
|
163 |
-
self.num_bits = len(WATERMARK_BITS)
|
164 |
-
self.encoder = WatermarkEncoder()
|
165 |
-
self.encoder.set_watermark("bits", self.watermark)
|
166 |
-
|
167 |
-
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
168 |
-
"""
|
169 |
-
Adds a predefined watermark to the input image
|
170 |
-
|
171 |
-
Args:
|
172 |
-
image: ([N,] B, RGB, H, W) in range [-1, 1]
|
173 |
-
|
174 |
-
Returns:
|
175 |
-
same as input but watermarked
|
176 |
-
"""
|
177 |
-
image = 0.5 * image + 0.5
|
178 |
-
squeeze = len(image.shape) == 4
|
179 |
-
if squeeze:
|
180 |
-
image = image[None, ...]
|
181 |
-
n = image.shape[0]
|
182 |
-
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
|
183 |
-
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
184 |
-
# watermarking libary expects input as cv2 BGR format
|
185 |
-
for k in range(image_np.shape[0]):
|
186 |
-
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
187 |
-
image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
|
188 |
-
image.device
|
189 |
-
)
|
190 |
-
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
191 |
-
if squeeze:
|
192 |
-
image = image[0]
|
193 |
-
image = 2 * image - 1
|
194 |
-
return image
|
195 |
-
|
196 |
-
|
197 |
-
# A fixed 48-bit message that was choosen at random
|
198 |
-
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
|
199 |
-
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
200 |
-
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
201 |
-
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
|
|
4 |
import torch
|
5 |
from einops import rearrange
|
6 |
from huggingface_hub import hf_hub_download
|
|
|
7 |
from safetensors.torch import load_file as load_sft
|
8 |
|
9 |
from flux.model import Flux, FluxParams
|
|
|
154 |
missing, unexpected = ae.load_state_dict(sd, strict=False)
|
155 |
print_load_warning(missing, unexpected)
|
156 |
return ae
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|