import argparse
from pathlib import Path
from typing import Dict
import safetensors.torch
import torch
import json
import shutil


def load_text_encoder(index_path: Path) -> Dict:
    """Load every safetensors shard listed in an index file into one dict.

    Args:
        index_path: Path to a JSON index. The values of its ``"weight_map"``
            entry are treated as shard file names located next to the index
            file. A missing ``"weight_map"`` is treated as empty.

    Returns:
        A single dict merging the tensors of all referenced shards, loaded
        on the CPU. If the same tensor name appears in several shards, the
        shard read last wins.
    """
    with open(index_path, "r") as index_file:
        index: Dict = json.load(index_file)

    # Several tensors usually point at the same shard; read each shard once.
    shard_names = set(index.get("weight_map", {}).values())

    merged = {}
    for shard_name in shard_names:
        shard = safetensors.torch.load_file(
            index_path.parent / shard_name, device="cpu"
        )
        merged.update(shard)
    return merged


def convert_unet(unet: Dict, add_prefix=True) -> Dict:
    """Optionally namespace UNet state-dict keys for a combined checkpoint.

    Args:
        unet: State dict of the UNet.
        add_prefix: When true, every key is prefixed with
            ``"model.diffusion_model."``; when false the input dict is
            returned unchanged (the very same object).

    Returns:
        The (possibly re-keyed) state dict.
    """
    if not add_prefix:
        return unet

    prefixed = {}
    for name, tensor in unet.items():
        prefixed["model.diffusion_model." + name] = tensor
    return prefixed


def convert_vae(vae_path: Path, add_prefix=True) -> Dict:
    """Build a VAE state dict from a directory holding the VAE artifacts.

    Reads ``autoencoder.pth`` from ``vae_path`` and, when present, folds the
    table stored in ``per_channel_statistics.json`` into the result as one
    tensor per column under ``per_channel_statistics.<column>``.

    Args:
        vae_path: Directory containing ``autoencoder.pth`` and optionally
            ``per_channel_statistics.json``.
        add_prefix: When true, every key (weights and statistics) is
            prefixed with ``"vae."``.

    Returns:
        Dict mapping key names to tensors; weights come first, followed by
        the per-channel statistics (if any).
    """
    prefix = "vae." if add_prefix else ""

    # Force CPU placement: without map_location, a checkpoint that was saved
    # from GPU tensors cannot be loaded on a machine without that device.
    # This also matches load_text_encoder, which loads with device="cpu".
    state_dict = torch.load(
        vae_path / "autoencoder.pth", weights_only=True, map_location="cpu"
    )
    result = {prefix + key: value for key, value in state_dict.items()}

    stats_path = vae_path / "per_channel_statistics.json"
    if stats_path.exists():
        with open(stats_path, "r", encoding="utf-8") as f:
            stats = json.load(f)
        # "data" is a list of rows; transpose it so that each named column
        # becomes a single 1-D tensor.
        columns = zip(*stats["data"])
        for col, vals in zip(stats["columns"], columns):
            result[f"{prefix}per_channel_statistics.{col}"] = torch.tensor(vals)

    return result


def convert_encoder(encoder: Dict) -> Dict:
    """Namespace text-encoder keys under ``text_encoders.t5xxl.transformer.``.

    Args:
        encoder: State dict of the text encoder.

    Returns:
        A new dict with the same values and every key prefixed.
    """
    prefix = "text_encoders.t5xxl.transformer."
    renamed = {}
    for name, tensor in encoder.items():
        renamed[prefix + name] = tensor
    return renamed


def save_config(config_src: str, config_dst: str):
    """Copy the config file at ``config_src`` to ``config_dst``.

    This is a thin wrapper over ``shutil.copy``; the arguments are handed to
    it unchanged.
    """
    shutil.copy(src=config_src, dst=config_dst)


def load_vae_config(vae_path: Path) -> str:
    """Locate ``config.json`` inside the VAE directory.

    Args:
        vae_path: Directory expected to contain ``config.json``.

    Returns:
        The path of the config file, as a string. The file is not read.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in ``vae_path``.
    """
    candidate = vae_path.joinpath("config.json")
    if candidate.exists():
        return str(candidate)
    raise FileNotFoundError(f"VAE config file {candidate} not found.")


def main(
    unet_path: str,
    vae_path: str,
    out_path: str,
    mode: str,
    unet_config_path: str = None,
    scheduler_config_path: str = None,
) -> None:
    """Convert a UNet checkpoint and a VAE directory to safetensors.

    Args:
        unet_path: Path to the UNet checkpoint readable by ``torch.load``.
        vae_path: Directory with the VAE artifacts (see ``convert_vae``).
        out_path: In ``"single"`` mode, the output safetensors file. In
            ``"separate"`` mode, the output directory under which ``unet/``,
            ``vae/`` and ``scheduler/`` are created.
        mode: ``"single"`` writes one file holding the prefixed UNet and VAE
            weights; ``"separate"`` writes un-prefixed UNet and VAE files
            into their own sub-directories together with config files.
        unet_config_path: Optional UNet config to copy (separate mode only).
        scheduler_config_path: Optional scheduler config to copy (separate
            mode only).

    Raises:
        ValueError: If ``mode`` is neither ``"single"`` nor ``"separate"``.
        FileNotFoundError: In separate mode, if the VAE directory has no
            ``config.json``.
    """
    # Fail fast: an unknown mode used to load both checkpoints and then
    # silently write nothing.
    if mode not in ("single", "separate"):
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'single' or 'separate'."
        )
    add_prefix = mode == "single"

    # map_location="cpu" lets checkpoints saved from GPU tensors be converted
    # on machines without that device (consistent with load_text_encoder).
    unet = convert_unet(
        torch.load(unet_path, weights_only=True, map_location="cpu"),
        add_prefix=add_prefix,
    )
    vae = convert_vae(Path(vae_path), add_prefix=add_prefix)

    if mode == "single":
        # Key prefixes keep the two state dicts from colliding when merged.
        safetensors.torch.save_file({**unet, **vae}, out_path)
        return

    # The VAE config is only consumed in separate mode, so only require it
    # here; it is still checked before anything is written to disk.
    vae_config_path = load_vae_config(Path(vae_path))

    # Create directories for unet, vae, and scheduler. The scheduler
    # directory is created even when no scheduler config is supplied.
    out_dir = Path(out_path)
    unet_dir = out_dir / "unet"
    vae_dir = out_dir / "vae"
    scheduler_dir = out_dir / "scheduler"
    for directory in (unet_dir, vae_dir, scheduler_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # Save the weights; note that each file name carries its component prefix.
    safetensors.torch.save_file(
        unet, unet_dir / "unet_diffusion_pytorch_model.safetensors"
    )
    safetensors.torch.save_file(
        vae, vae_dir / "vae_diffusion_pytorch_model.safetensors"
    )

    # Copy config files next to the weights. The VAE config always exists at
    # this point (load_vae_config raised otherwise); the others are optional.
    if unet_config_path:
        save_config(unet_config_path, unet_dir / "config.json")
    save_config(vae_config_path, vae_dir / "config.json")
    if scheduler_config_path:
        save_config(scheduler_config_path, scheduler_dir / "scheduler_config.json")


if __name__ == "__main__":
    # Command-line entry point: every option maps one-to-one onto a keyword
    # argument of main().
    cli = argparse.ArgumentParser()

    # Input checkpoints and output location.
    cli.add_argument("--unet_path", "-u", type=str, default="unet/ema-002.pt")
    cli.add_argument("--vae_path", "-v", type=str, default="vae/")
    cli.add_argument("--out_path", "-o", type=str, default="xora.safetensors")

    # Output layout.
    cli.add_argument(
        "--mode",
        "-m",
        type=str,
        choices=["single", "separate"],
        default="single",
        help="Choose 'single' for the original behavior, 'separate' to save unet and vae separately.",
    )

    # Optional config files, only copied in separate mode.
    optional_configs = (
        ("--unet_config_path", "Path to the UNet config file (for separate mode)"),
        (
            "--scheduler_config_path",
            "Path to the Scheduler config file (for separate mode)",
        ),
    )
    for flag, help_text in optional_configs:
        cli.add_argument(flag, type=str, help=help_text)

    main(**vars(cli.parse_args()))