import json
from glob import glob
from pathlib import Path

import tyro

def raw_params_to_readable(params: int) -> str:
    """Render a raw parameter count in billions with one decimal place.

    For example, 2_567_000_000 becomes ``"2.6B"``.
    """
    billions = params / 1e9
    return format(billions, ".1f") + "B"

def main(results_dir: Path, output_file: Path) -> None:
    """Collect per-model parameter metadata from result files into one JSON file.

    ``results_dir`` is expected to contain model directories two levels deep
    (``<org>/<model>/``). Each must hold exactly one file matching
    ``bs1+*+steps25+results.json`` whose ``num_parameters`` entry maps module
    names to parameter counts.

    The output JSON maps ``"<org>/<model>"`` to an entry with:
    - ``url``: the Hugging Face page for the model.
    - ``nickname``: a title-cased display name.
    - ``total_params``: the summed parameter count, in readable form.
    - ``denoising_params``: the parameter count of the ``unet`` module if
      present, otherwise the ``transformer`` module.
    - ``resolution``: a placeholder ``"NA"``.

    Args:
        results_dir: Root directory containing ``<org>/<model>/`` subdirectories.
        output_file: Destination JSON path. Parent directories are created.

    Raises:
        ValueError: If a model directory does not contain exactly one
            matching results file, or if a model name is seen twice.
    """
    # The file only imports ``glob.glob``; ``escape`` makes paths literal in patterns.
    from glob import escape

    output_file.parent.mkdir(parents=True, exist_ok=True)
    print(f"{results_dir} -> {output_file}")

    models = {}
    # Escape the root so glob metacharacters in the path ("[", "*", "?") are
    # matched literally instead of being interpreted as wildcards.
    for model_dir in sorted(glob(f"{escape(str(results_dir))}/*/*")):
        model_path = Path(model_dir)
        if not model_path.is_dir():
            # Only <org>/<model>/ directories are relevant; skip stray files.
            continue
        # Build the "<org>/<model>" id from path components rather than
        # splitting on "/", so Windows path separators are handled too.
        model_name = "/".join(model_path.parts[-2:])
        print(f"  {model_name}")

        result_file_cand = glob(f"{escape(model_dir)}/bs1+*+steps25+results.json")
        if len(result_file_cand) != 1:
            raise ValueError(
                f"{model_name}: expected exactly one results file, "
                f"found {len(result_file_cand)}"
            )
        with open(result_file_cand[0], encoding="utf-8") as f:
            results_data = json.load(f)

        num_parameters = results_data["num_parameters"]
        # UNet-based pipelines report a "unet" module; otherwise assume a
        # transformer-based denoiser.
        denoising_module_name = "unet" if "unet" in num_parameters else "transformer"
        model_info = dict(
            url=f"https://huggingface.co/{model_name}",
            nickname=model_path.name.replace("-", " ").title(),
            total_params=raw_params_to_readable(sum(num_parameters.values())),
            denoising_params=raw_params_to_readable(num_parameters[denoising_module_name]),
            resolution="NA",
        )
        if model_name in models:
            raise ValueError(f"duplicate model name: {model_name}")
        models[model_name] = model_info

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(models, f, indent=2)


if __name__ == "__main__":
    # Run main() as a command-line entry point, with its parameters
    # (results_dir, output_file) supplied via tyro.
    tyro.cli(main)