|
import argparse |
|
from pathlib import Path |
|
from typing import Dict |
|
import safetensors.torch |
|
import torch |
|
import json |
|
import shutil |
|
|
|
|
|
def load_text_encoder(index_path: Path) -> Dict:
    """Load all tensors of a sharded safetensors checkpoint described by an index file.

    The index is a JSON file whose ``"weight_map"`` entry maps tensor names to
    shard file names. Shard paths are resolved relative to the directory that
    contains the index. Each referenced shard is read once, onto the CPU, and
    every tensor it contains is merged into a single dict.

    Args:
        index_path: Path to the index JSON file (a plain ``str`` is accepted too).

    Returns:
        Dict mapping tensor name to tensor. It is empty when the index has no
        ``"weight_map"``. If two shards contain the same tensor name, the shard
        whose file name sorts last wins.
    """
    index_path = Path(index_path)
    # Index files are JSON and therefore UTF-8; do not depend on the locale.
    with open(index_path, "r", encoding="utf-8") as f:
        index: Dict = json.load(f)

    loaded_tensors = {}
    # Many tensors share one shard, so deduplicate. Sort so that the load order,
    # and therefore the resulting dict order, is reproducible across runs
    # (set iteration order of strings varies with hash randomization).
    for part_file in sorted(set(index.get("weight_map", {}).values())):
        tensors = safetensors.torch.load_file(
            index_path.parent / part_file, device="cpu"
        )
        loaded_tensors.update(tensors)

    return loaded_tensors
|
|
|
|
|
def convert_unet(unet: Dict, add_prefix=True) -> Dict:
    """Optionally namespace every UNet state-dict key under ``model.diffusion_model.``.

    Args:
        unet: Mapping of parameter name to tensor.
        add_prefix: When false, the input mapping itself (not a copy) is returned.

    Returns:
        A new dict with prefixed keys, or *unet* unchanged.
    """
    if not add_prefix:
        return unet
    namespace = "model.diffusion_model."
    return {namespace + name: tensor for name, tensor in unet.items()}
|
|
|
|
|
def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
    """Load a VAE checkpoint directory into one flat, optionally prefixed state dict.

    Reads ``autoencoder.pth`` from *vae_path* with tensors mapped onto the CPU.
    If ``per_channel_statistics.json`` is present, it is read as well. That file
    is expected to hold a ``"columns"`` list of names and a ``"data"`` list of
    rows. Each column becomes one tensor stored under
    ``per_channel_statistics.<column>``.

    Args:
        vae_path: Directory holding the VAE files (a plain ``str`` is accepted too).
        add_prefix: When true, every key is prefixed with ``"vae."``.

    Returns:
        Dict with the autoencoder state-dict entries first, then the per-channel
        statistics entries. A statistics entry overwrites a colliding key.
    """
    vae_path = Path(vae_path)
    prefix = "vae." if add_prefix else ""

    # map_location="cpu" lets checkpoints saved from GPU tensors be converted on
    # machines without CUDA, matching the CPU loading used for the text encoder.
    state_dict = torch.load(
        vae_path / "autoencoder.pth", map_location="cpu", weights_only=True
    )
    result = {prefix + key: value for key, value in state_dict.items()}

    stats_path = vae_path / "per_channel_statistics.json"
    if stats_path.exists():
        # JSON is UTF-8; do not depend on the platform's locale encoding.
        with open(stats_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Rows are stored row-major; transpose to get one sequence per column.
        per_column = list(zip(*data["data"]))
        for col, vals in zip(data["columns"], per_column):
            result[f"{prefix}per_channel_statistics.{col}"] = torch.tensor(vals)

    return result
|
|
|
|
|
def convert_encoder(encoder: Dict) -> Dict:
    """Return a copy of *encoder* with every key placed under ``text_encoders.t5xxl.transformer.``.

    Args:
        encoder: Mapping of parameter name to tensor.

    Returns:
        A new dict with the same values and prefixed keys, in the original order.
    """
    namespace = "text_encoders.t5xxl.transformer."
    renamed = {}
    for name, tensor in encoder.items():
        renamed[namespace + name] = tensor
    return renamed
|
|
|
|
|
def save_config(config_src: str, config_dst: str):
    """Copy the config file at *config_src* to *config_dst* (data and permission bits).

    Returns:
        None.
    """
    source, destination = config_src, config_dst
    shutil.copy(source, destination)
|
|
|
|
|
def load_vae_config(vae_path: Path) -> str:
    """Return the path of ``config.json`` inside *vae_path* as a string.

    Raises:
        FileNotFoundError: if ``vae_path / "config.json"`` does not exist.
    """
    candidate = vae_path / "config.json"
    if candidate.exists():
        return str(candidate)
    raise FileNotFoundError(f"VAE config file {candidate} not found.")
|
|
|
|
|
def main(
    unet_path: str,
    vae_path: str,
    out_path: str,
    mode: str,
    unet_config_path: str = None,
    scheduler_config_path: str = None,
) -> None:
    """Convert a UNet checkpoint and a VAE directory to safetensors output.

    Modes:
        ``"single"``: write one safetensors file at *out_path* containing the
            UNet (keys prefixed ``model.diffusion_model.``) and the VAE (keys
            prefixed ``vae.``). The config-path arguments are not used.
        ``"separate"``: treat *out_path* as a directory and create ``unet/``,
            ``vae/`` and ``scheduler/`` under it. Weights go into
            ``diffusion_pytorch_model.safetensors``. ``<vae_path>/config.json``
            is required and copied, and the optional UNet/scheduler configs are
            copied when given.

    Args:
        unet_path: Path of the UNet checkpoint readable by ``torch.load``.
        vae_path: Directory holding ``autoencoder.pth`` (and, for separate
            mode, ``config.json``).
        out_path: Output file (single) or output directory (separate).
        mode: ``"single"`` or ``"separate"``.
        unet_config_path: Optional UNet config to copy (separate mode only).
        scheduler_config_path: Optional scheduler config to copy (separate mode only).

    Raises:
        ValueError: if *mode* is not ``"single"`` or ``"separate"``.
        FileNotFoundError: in separate mode, if the VAE config is missing.
    """
    # Validate before loading any (potentially large) weights.
    if mode not in ("single", "separate"):
        raise ValueError(f"Unknown mode {mode!r}; expected 'single' or 'separate'.")
    single = mode == "single"
    vae_src = Path(vae_path)

    # Only separate mode uses the VAE config. Check for it up front so a missing
    # file fails fast instead of after the slow weight loading.
    vae_config_path = None if single else load_vae_config(vae_src)

    # map_location="cpu" lets GPU-saved checkpoints be converted without CUDA.
    unet = convert_unet(
        torch.load(unet_path, map_location="cpu", weights_only=True),
        add_prefix=single,
    )
    vae = convert_vae(vae_src, add_prefix=single)

    if single:
        safetensors.torch.save_file({**unet, **vae}, str(out_path))
        return

    out_dir = Path(out_path)
    unet_dir = out_dir / "unet"
    vae_dir = out_dir / "vae"
    scheduler_dir = out_dir / "scheduler"
    for directory in (unet_dir, vae_dir, scheduler_dir):
        directory.mkdir(parents=True, exist_ok=True)

    safetensors.torch.save_file(
        unet, str(unet_dir / "diffusion_pytorch_model.safetensors")
    )
    safetensors.torch.save_file(
        vae, str(vae_dir / "diffusion_pytorch_model.safetensors")
    )

    if unet_config_path:
        save_config(unet_config_path, unet_dir / "config.json")
    save_config(vae_config_path, vae_dir / "config.json")
    if scheduler_config_path:
        save_config(scheduler_config_path, scheduler_dir / "scheduler_config.json")
|
|
|
|
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser()

    # Input/output locations: (flags, default value).
    for flags, fallback in (
        (("--unet_path", "-u"), "unet/ema-002.pt"),
        (("--vae_path", "-v"), "vae/"),
        (("--out_path", "-o"), "xora.safetensors"),
    ):
        cli.add_argument(*flags, type=str, default=fallback)

    cli.add_argument(
        "--mode",
        "-m",
        type=str,
        choices=["single", "separate"],
        default="single",
        help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
    )

    # Optional config files that only the separate mode copies.
    for flag, description in (
        ("--unet_config_path", "Path to the UNet config file (for separate mode)"),
        ("--scheduler_config_path", "Path to the Scheduler config file (for separate mode)"),
    ):
        cli.add_argument(flag, type=str, help=description)

    # vars() turns the parsed namespace into keyword arguments matching main().
    main(**vars(cli.parse_args()))
|
|