import json
import os
from dataclasses import dataclass

import cv2
import numpy as np
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from safetensors.torch import load_file as load_sft
from optimum.quanto import requantize

from .model import Flux, FluxParams
from .controlnet import ControlNetFlux
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder
from .annotator.dwpose import DWposeDetector
from .annotator.mlsd import MLSDdetector
from .annotator.canny import CannyDetector
from .annotator.midas import MidasDetector
from .annotator.hed import HEDdetector
from .annotator.tile import TileDetector


def load_safetensors(path):
    """Load every tensor stored in a ``.safetensors`` file onto the CPU.

    Args:
        path: Path to the ``.safetensors`` file.

    Returns:
        Dict mapping each key in the file to its tensor.
    """
    with safe_open(path, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys()}


def get_lora_rank(checkpoint):
    """Infer the LoRA rank from a LoRA state dict.

    The rank is taken as the size of dim 0 of the first tensor whose key
    ends with ``.down.weight``.

    Returns:
        The rank, or ``None`` when no ``.down.weight`` key is present.
    """
    for key in checkpoint:
        if key.endswith(".down.weight"):
            return checkpoint[key].shape[0]
    return None


def load_checkpoint(local_path, repo_id, name):
    """Load a state dict from a local file or from a Hugging Face repo.

    Args:
        local_path: Local checkpoint path. Takes precedence over ``repo_id``.
            Paths containing ``.safetensors`` are read with safetensors;
            anything else goes through ``torch.load`` on the CPU.
        repo_id: Hugging Face repo id, used when ``local_path`` is None.
        name: File name inside ``repo_id``.

    Returns:
        The loaded state dict.

    Raises:
        ValueError: If neither ``local_path`` nor both ``repo_id`` and
            ``name`` are provided.
    """
    if local_path is not None:
        if '.safetensors' in local_path:
            print("Loading .safetensors checkpoint...")
            return load_safetensors(local_path)
        print("Loading checkpoint...")
        # NOTE(review): torch.load unpickles arbitrary objects -- only use it
        # on trusted checkpoint files.
        return torch.load(local_path, map_location='cpu')
    if repo_id is not None and name is not None:
        print("Loading checkpoint from repo id...")
        return load_from_repo_id(repo_id, name)
    raise ValueError(
        "LOADING ERROR: you must specify local_path or repo_id with name in HF to download"
    )


def c_crop(image):
    """Return the largest centered square crop of a PIL image."""
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))


class Annotator:
    """Produce a ControlNet conditioning map from an image.

    Supported names: ``canny``, ``openpose``, ``depth``, ``hed``, ``hough``
    and ``tile``.
    """

    def __init__(self, name: str, device: str):
        """Instantiate the detector for ``name``.

        Args:
            name: Annotator name (see class docstring).
            device: Device passed to detectors that accept one (only
                ``openpose`` here).

        Raises:
            ValueError: If ``name`` is not a supported annotator.
        """
        if name == "canny":
            processor = CannyDetector()
        elif name == "openpose":
            processor = DWposeDetector(device)
        elif name == "depth":
            processor = MidasDetector()
        elif name == "hed":
            processor = HEDdetector()
        elif name == "hough":
            processor = MLSDdetector()
        elif name == "tile":
            processor = TileDetector()
        else:
            # Fail fast here rather than leaving ``processor`` unbound.
            raise ValueError(f"Unknown annotator name: {name!r}")
        self.name = name
        self.processor = processor

    def __call__(self, image: Image.Image, width: int, height: int):
        """Center-crop, resize to ``(width, height)`` and annotate ``image``.

        Returns:
            A 3-D numpy array. Detector outputs that are not 3-D (expected to
            be 2-D maps) get a trailing axis and are repeated 3 times along it.
        """
        image = c_crop(image)
        image = image.resize((width, height))
        image = np.array(image)
        if self.name == "canny":
            result = self.processor(image, low_threshold=100, high_threshold=200)
        elif self.name == "hough":
            result = self.processor(image, thr_v=0.05, thr_d=5)
        elif self.name == "depth":
            # The depth detector returns a pair; only the first item is used.
            result, _ = self.processor(image)
        else:
            result = self.processor(image)

        if result.ndim != 3:
            result = result[:, :, None]
            result = np.concatenate([result, result, result], axis=2)
        return result


@dataclass
class ModelSpec:
    """Everything needed to locate and build one FLUX variant."""

    params: FluxParams
    ae_params: AutoEncoderParams
    ckpt_path: str | None  # local flow checkpoint (from an env var), if any
    ae_path: str | None  # local autoencoder checkpoint (from an env var), if any
    repo_id: str | None  # HF repo holding the flow weights
    repo_flow: str | None  # flow weights file name inside ``repo_id``
    repo_ae: str | None  # autoencoder file name inside ``repo_id_ae``
    repo_id_ae: str | None  # HF repo holding the autoencoder weights


def _flux_params(guidance_embed: bool) -> FluxParams:
    """Return a fresh FluxParams shared by all variants; only ``guidance_embed`` differs."""
    return FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=guidance_embed,
    )


def _ae_params() -> AutoEncoderParams:
    """Return a fresh AutoEncoderParams shared by all variants."""
    return AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )


configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=_flux_params(guidance_embed=True),
        ae_path=os.getenv("AE"),
        ae_params=_ae_params(),
    ),
    "flux-dev-fp8": ModelSpec(
        repo_id="XLabs-AI/flux-dev-fp8",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux-dev-fp8.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV_FP8"),
        params=_flux_params(guidance_embed=True),
        ae_path=os.getenv("AE"),
        ae_params=_ae_params(),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_id_ae="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=_flux_params(guidance_embed=False),
        ae_path=os.getenv("AE"),
        ae_params=_ae_params(),
    ),
}


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Print missing/unexpected state-dict keys, separated by a rule when both exist."""
    if missing:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if missing and unexpected:
        print("\n" + "-" * 79 + "\n")
    if unexpected:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_from_repo_id(repo_id, checkpoint_name):
    """Download ``checkpoint_name`` from a HF repo and load it as a CPU state dict."""
    ckpt_path = hf_hub_download(repo_id, checkpoint_name)
    sd = load_sft(ckpt_path, device='cpu')
    return sd


def _resolve_flow_ckpt_path(name: str, hf_download: bool, rename_sft: bool = False) -> str | None:
    """Return the flow checkpoint path for ``name``.

    The path comes from the local env-configured path if set, otherwise from
    a HF download when allowed. Returns ``None`` when neither is available.
    ``rename_sft`` replaces ``"sft"`` with ``"safetensors"`` in the repo file name.
    """
    spec = configs[name]
    if spec.ckpt_path is not None:
        return spec.ckpt_path
    if spec.repo_id is None or spec.repo_flow is None or not hf_download:
        return None
    filename = spec.repo_flow.replace("sft", "safetensors") if rename_sft else spec.repo_flow
    return hf_hub_download(spec.repo_id, filename)


def _load_weights_into(module, ckpt_path: str, device: str | torch.device) -> None:
    """Load a safetensors checkpoint into ``module`` (non-strict, assigning tensors)."""
    # load_sft doesn't support torch.device
    sd = load_sft(ckpt_path, device=str(device))
    # assign=True so tensors replace the meta-device placeholders created at init.
    missing, unexpected = module.load_state_dict(sd, strict=False, assign=True)
    print_load_warning(missing, unexpected)


def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    """Build the bfloat16 Flux model for ``name`` and load its checkpoint if available.

    When a checkpoint is found, the model is created on the meta device and
    its weights are materialized on ``device`` by the load. Otherwise it is
    randomly initialized directly on ``device``.
    """
    print("Init model")
    ckpt_path = _resolve_flow_ckpt_path(name, hf_download)
    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint")
        _load_weights_into(model, ckpt_path, device)
    return model


def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    """Like :func:`load_flow_model` but without the bfloat16 cast.

    When downloading, ``"sft"`` in the repo file name is replaced with
    ``"safetensors"``.
    """
    print("Init model")
    ckpt_path = _resolve_flow_ckpt_path(name, hf_download, rename_sft=True)
    with torch.device("meta" if ckpt_path is not None else device):
        model = Flux(configs[name].params)

    if ckpt_path is not None:
        print("Loading checkpoint")
        _load_weights_into(model, ckpt_path, device)
    return model


def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    """Build a Flux model and requantize it from a pre-quantized checkpoint.

    The quantization map (``flux_dev_quantization_map.json``) is always
    fetched from ``configs[name].repo_id``, even when the weights come from a
    local path.

    Raises:
        ValueError: If no checkpoint path is configured and downloading is
            disabled or impossible.
    """
    print("Init model")
    spec = configs[name]
    ckpt_path = _resolve_flow_ckpt_path(name, hf_download)
    if ckpt_path is None:
        raise ValueError(
            f"No checkpoint available for {name!r}: set its checkpoint env var or enable hf_download"
        )
    # Fetched unconditionally: the map is required for requantize() regardless
    # of where the weights came from.
    json_path = hf_hub_download(spec.repo_id, 'flux_dev_quantization_map.json')

    model = Flux(spec.params).to(torch.bfloat16)

    print("Loading checkpoint")
    # load_sft doesn't support torch.device
    sd = load_sft(ckpt_path, device='cpu')
    with open(json_path, "r") as f:
        quantization_map = json.load(f)
    print("Start a quantization process...")
    requantize(model, sd, quantization_map, device=device)
    print("Model is quantized!")
    return model


def load_controlnet(name, device, transformer=None):
    """Create a ControlNetFlux on ``device``.

    If ``transformer`` is given, its matching weights are copied in
    (non-strict).
    """
    with torch.device(device):
        controlnet = ControlNetFlux(configs[name].params)
    if transformer is not None:
        controlnet.load_state_dict(transformer.state_dict(), strict=False)
    return controlnet


def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Load the T5 text encoder (bfloat16) on ``device``."""
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    """Load the CLIP ViT-L/14 text encoder (bfloat16, max_length 77) on ``device``."""
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)


def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
    """Build the autoencoder for ``name`` and load its weights if available.

    The checkpoint comes from the env-configured ``ae_path``, or is
    downloaded from ``repo_id_ae``/``repo_ae`` when allowed.
    """
    spec = configs[name]
    ckpt_path = spec.ae_path
    # Guard on repo_id_ae -- the repo the autoencoder is actually fetched from.
    if ckpt_path is None and spec.repo_id_ae is not None and spec.repo_ae is not None and hf_download:
        ckpt_path = hf_hub_download(spec.repo_id_ae, spec.repo_ae)

    # Loading the autoencoder
    print("Init AE")
    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(spec.ae_params)

    if ckpt_path is not None:
        _load_weights_into(ae, ckpt_path, device)
    return ae


class WatermarkEmbedder:
    """Embed a fixed bit watermark into image tensors.

    The encoder's ``dwtDct`` method is applied to every image.
    """

    def __init__(self, watermark):
        self.watermark = watermark
        self.num_bits = len(WATERMARK_BITS)
        # NOTE(review): WatermarkEncoder is not imported in this file, so this
        # line raises NameError until it is (presumably
        # ``from imwatermark import WatermarkEncoder`` -- confirm).
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, RGB, H, W) in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        image = 0.5 * image + 0.5
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n, b, c, h, w = image.shape
        # torch (n, b, c, h, w) in [0, 1] -> numpy (n*b, h, w, c) in [0, 255].
        # Channels are reversed because the watermarking library expects cv2 BGR order.
        image_np = (
            (255 * image).detach().cpu().permute(0, 1, 3, 4, 2).reshape(n * b, h, w, c).numpy()[:, :, :, ::-1]
        )
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        # Back to RGB. Copy first: torch.from_numpy rejects negative strides.
        rgb = np.ascontiguousarray(image_np[:, :, :, ::-1])
        image = torch.from_numpy(rgb).reshape(n, b, h, w, c).permute(0, 1, 4, 2, 3).to(image.device)
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        return 2 * image - 1


# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1.
# NOTE(review): bin() drops the two leading zeros, so this list holds 46 bits,
# not 48. Left unchanged to stay compatible with existing watermarks.
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]