from huggingface_hub import model_info, hf_hub_download
import gradio as gr
import json

# Names skipped when checking and sizing pipeline components: components that
# are not measured, plus the metadata keys stored in `model_index.json`.
COMPONENT_FILTER = [
    "scheduler",
    "feature_extractor",
    "tokenizer",
    "tokenizer_2",
    "_class_name",
    "_diffusers_version",
]

# Markdown rendered below the form.
ARTICLE = """
## Notes on how to use the `controlnet_id` and `t2i_adapter_id` fields

Both `controlnet_id` and `t2i_adapter_id` fields support passing multiple comma-separated checkpoint ids,
e.g., "thibaud/controlnet-openpose-sdxl-1.0,diffusers/controlnet-canny-sdxl-1.0". For
`t2i_adapter_id`, this could be, e.g., "TencentARC/t2iadapter_keypose_sd14v1,TencentARC/t2iadapter_depth_sd14v1".
Only one of `controlnet_id` and `t2i_adapter_id` can be set at a time.

Make sure the base `pipeline_id` you pass matches these checkpoints. For example,
passing `pipeline_id` as "runwayml/stable-diffusion-v1-5" and `controlnet_id` as "thibaud/controlnet-openpose-sdxl-1.0"
won't result in an error, but these two aren't meant to be compatible. You should pass
a `controlnet_id` that is compatible with "runwayml/stable-diffusion-v1-5".

For further clarification on this topic, feel free to open a [discussion](https://huggingface.co/spaces/diffusers/compute-pipeline-size/discussions).

📔 Also, note that the `revision` field only applies to `pipeline_id`. It has no effect on
`controlnet_id` or `t2i_adapter_id`.
"""

# Precision choices offered in the UI. "fp32" stands for the default, untagged
# checkpoint files; the other entries are matched as substrings of file names.
ALLOWED_VARIANTS = ["fp32", "fp16", "bf16"]


def format_size(num: int) -> str:
    """Return ``num`` bytes as a short human-readable string with decimal (SI) prefixes.

    Every step divides by 1000, so ``1500`` renders as ``"1.5K"`` and
    ``1_234_567`` as ``"1.2M"``. Anything still at least 1000 after the ``Z``
    prefix is reported with ``Y``. Adapted from https://stackoverflow.com/a/1094933
    """
    prefixes = ("", "K", "M", "G", "T", "P", "E", "Z")
    scaled = float(num)
    step = 0
    # `not (x < 1000)` rather than `x >= 1000` so NaN keeps scaling and falls
    # through to the "Y" suffix instead of stopping at the first prefix.
    while step < len(prefixes) and not abs(scaled) < 1000.0:
        scaled /= 1000.0
        step += 1
    if step == len(prefixes):
        return f"{scaled:.1f}Y"
    return f"{scaled:3.1f}{prefixes[step]}"


def format_output(pipeline_id, memory_mapping, controlnet_mapping=None, t2i_adapter_mapping=None):
    """Render the computed sizes as a Markdown report.

    Args:
        pipeline_id: Repo id used as the top-level heading.
        memory_mapping: Component name -> size in bytes for the pipeline itself.
        controlnet_mapping: Optional ControlNet repo id -> size in bytes.
        t2i_adapter_mapping: Optional T2I-Adapter repo id -> size in bytes.

    Returns:
        A Markdown string with one bullet per entry. Empty or missing mappings
        produce no section.
    """

    def bullets(mapping):
        # One "* name: size" line per entry, in the mapping's iteration order.
        return [f"* {name}: {format_size(size)}\n" for name, size in mapping.items()]

    parts = [f"## {pipeline_id}\n"]
    if memory_mapping:
        parts.extend(bullets(memory_mapping))
    if controlnet_mapping:
        parts.append("\n## ControlNet(s)\n")
        parts.extend(bullets(controlnet_mapping))
    if t2i_adapter_mapping:
        parts.append("\n## T2I-Adapter(s)\n")
        parts.extend(bullets(t2i_adapter_mapping))

    return "".join(parts)


def load_model_index(pipeline_id, token=None, revision=None):
    """Download and parse ``model_index.json`` from a Diffusers pipeline repo.

    Args:
        pipeline_id: Hub repo id of the pipeline.
        token: Optional Hub token (needed for private repos).
        revision: Optional branch, tag, or commit. ``None`` lets
            ``hf_hub_download`` use its default revision.

    Returns:
        The parsed JSON object. Callers expect a dict keyed by component name.
    """
    index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
    # JSON is UTF-8 by spec; don't rely on the platform's default text encoding.
    with open(index_path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_individual_model_memory(id, token, variant, extension):
    """Return the size in bytes of a single checkpoint file from the Hub repo ``id``.

    Used for ControlNet / T2I-Adapter repos. The first file, in the order listed
    by ``model_info``, whose name ends with ``extension`` is chosen, where:

    * if ``variant`` is set (and is not ``"fp32"``), the name must contain it;
    * otherwise the name must contain none of the non-fp32 variant tags
      (e.g. ``"fp16"``), so the untagged checkpoint is picked.

    Args:
        id: Hub repo id. (Shadows the builtin; the name is kept for callers.)
        token: Optional Hub token for private repos.
        variant: Variant tag such as ``"fp16"``, or ``None`` / ``"fp32"`` for the
            default files.
        extension: Required file suffix, e.g. ``".safetensors"``.

    Raises:
        ValueError: if no file in the repo matches the requested extension/variant.
    """
    repo_id = id
    files_in_repo = model_info(repo_id, token=token, files_metadata=True).siblings
    tagged_variants = [tag for tag in ALLOWED_VARIANTS if tag != "fp32"]
    wants_variant = bool(variant) and variant != "fp32"

    def matches(filename):
        # Suffix match, so e.g. "*.safetensors.index.json" is not mistaken for weights.
        if not filename.endswith(extension):
            return False
        if wants_variant:
            return variant in filename
        return not any(tag in filename for tag in tagged_variants)

    candidate = next((f for f in files_in_repo if matches(f.rfilename)), None)
    if candidate is None:
        available = ", ".join(f.rfilename for f in files_in_repo)
        raise ValueError(
            f"Requested extension ({extension}) and variant ({variant}) not present in {repo_id}."
            f" Available files: {available}."
        )
    return candidate.size


def _blank_to_none(value):
    """Strip a text input and map blank strings (empty form fields) to None; pass non-strings through."""
    if isinstance(value, str):
        value = value.strip()
        return value or None
    return value


def _split_checkpoint_ids(ids):
    """Split a comma-separated list of Hub repo ids, dropping surrounding whitespace and empty entries."""
    return [part.strip() for part in ids.split(",") if part.strip()]


def _checkpoint_sizes(ids, token, variant, extension):
    """Map each repo id in the comma-separated ``ids`` string to its checkpoint size in bytes."""
    return {
        ckpt_id: get_individual_model_memory(ckpt_id, token=token, variant=variant, extension=extension)
        for ckpt_id in _split_checkpoint_ids(ids)
    }


def _is_selected_weight_file(basename, variant, extension, variant_tags):
    """Return True if ``basename`` is a weight file to count for the requested extension/variant.

    Names containing "ema" (e.g. ``*.non_ema.safetensors``) are never counted.
    With a variant, the name must contain it. Without one, names carrying any
    tag from ``variant_tags`` (e.g. "fp16") are skipped, so only the untagged
    checkpoint is counted.
    """
    if basename.endswith(".json") or not basename.endswith(extension):
        return False
    if "ema" in basename:
        return False
    if variant is not None:
        return variant in basename
    return not any(tag in basename for tag in variant_tags)


def _group_files_by_folder(files_in_repo):
    """Group repo files by top-level folder as (path-inside-folder, file object) pairs; root files are dropped."""
    grouped = {}
    for file_obj in files_in_repo:
        folder, sep, rest = file_obj.rfilename.partition("/")
        if sep:
            grouped.setdefault(folder, []).append((rest, file_obj))
    return grouped


def _validate_components(index_dict, files_by_folder, variant, extension, variant_tags):
    """Raise ValueError if a weight-bearing component has no file for the requested extension/variant."""
    weight_suffixes = (".bin", ".safetensors", extension)
    for component, spec in index_dict.items():
        if component in COMPONENT_FILTER or not (isinstance(spec, list) and len(spec) == 2):
            continue
        if None in spec:
            # Optional component saved as [null, null]: there is no folder to check.
            continue
        entries = files_by_folder.get(component)
        if not entries:
            raise ValueError(f"Problem with {component}: no files found under `{component}/`.")
        direct_names = [rest for rest, _ in entries if "/" not in rest]
        if not any(name.endswith(weight_suffixes) for name in direct_names):
            # A folder without any weight files (e.g. a tokenizer/processor not
            # listed in COMPONENT_FILTER) has nothing to size.
            continue
        if not any(_is_selected_weight_file(name, variant, extension, variant_tags) for name in direct_names):
            formatted_filenames = ", ".join(file_obj.rfilename for _, file_obj in entries)
            raise ValueError(
                f"Requested extension ({extension}) and variant ({variant}) not present for {component}."
                f" Available files for this component: {formatted_filenames}."
            )


def _sum_component_sizes(index_dict, files_by_folder, variant, extension, variant_tags):
    """Sum, per component folder, the sizes of the selected weight files directly inside it.

    Summing (rather than keeping a single file) makes sharded checkpoints, such
    as ``model-0000X-of-0000N.safetensors``, count in full for any component.
    """
    sizes = {}
    for component, entries in files_by_folder.items():
        if component not in index_dict or component in COMPONENT_FILTER:
            continue
        for rest, file_obj in entries:
            if "/" in rest or not _is_selected_weight_file(rest, variant, extension, variant_tags):
                continue
            sizes[component] = sizes.get(component, 0) + (file_obj.size or 0)
    return sizes


def get_component_wise_memory(
    pipeline_id,
    controlnet_id=None,
    t2i_adapter_id=None,
    token=None,
    variant=None,
    revision=None,
    extension=".safetensors",
):
    """Compute the checkpoint size of each component of a Diffusers pipeline.

    Args:
        pipeline_id: Hub repo id of the pipeline (required).
        controlnet_id: Optional comma-separated ControlNet repo ids.
        t2i_adapter_id: Optional comma-separated T2I-Adapter repo ids. Mutually
            exclusive with ``controlnet_id``.
        token: Optional Hub token for private repos.
        variant: "fp32" (or empty/None) for the default files, or a tag such as "fp16".
        revision: Optional revision of ``pipeline_id``. It is not applied to adapters.
        extension: Weight file suffix to measure. Empty/None falls back to ".safetensors".

    Returns:
        A Markdown report produced by ``format_output``.

    Raises:
        ValueError: if both adapter fields are set, ``pipeline_id`` is missing, or
            a component lacks a checkpoint for the requested extension/variant.
    """
    # Empty text fields come in as "" (possibly with whitespace); treat them as unset.
    pipeline_id = _blank_to_none(pipeline_id)
    controlnet_id = _blank_to_none(controlnet_id)
    t2i_adapter_id = _blank_to_none(t2i_adapter_id)
    token = _blank_to_none(token)
    revision = _blank_to_none(revision)

    if controlnet_id and t2i_adapter_id:
        raise ValueError("Both `controlnet_id` and `t2i_adapter_id` cannot be provided.")
    if pipeline_id is None:
        raise ValueError("`pipeline_id` must be provided.")

    # "fp32" means the untagged default checkpoints; an unselected radio means the same.
    if not variant or variant == "fp32":
        variant = None
    extension = extension or ".safetensors"

    # Handle ControlNet and T2I-Adapter.
    controlnet_mapping = t2i_adapter_mapping = None
    if controlnet_id is not None:
        controlnet_mapping = _checkpoint_sizes(controlnet_id, token, variant, extension)
    elif t2i_adapter_id is not None:
        t2i_adapter_mapping = _checkpoint_sizes(t2i_adapter_id, token, variant, extension)

    print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")

    # Load pipeline metadata.
    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
    index_dict = load_model_index(pipeline_id, token=token, revision=revision)
    print(f"Index dict: {index_dict}")

    # Read-only use of the module-level filter/variant lists: nothing global is mutated here.
    variant_tags = [tag for tag in ALLOWED_VARIANTS if tag != "fp32"]
    files_by_folder = _group_files_by_folder(files_in_repo)
    _validate_components(index_dict, files_by_folder, variant, extension, variant_tags)
    component_wise_memory = _sum_component_sizes(index_dict, files_by_folder, variant, extension, variant_tags)

    return format_output(pipeline_id, component_wise_memory, controlnet_mapping, t2i_adapter_mapping)


# Gradio UI. The input order matches the positional parameters of
# `get_component_wise_memory`; each example row follows the same order.
demo = gr.Interface(
    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
    description="Pipelines containing text encoders with sharded checkpoints are also supported"
    " (PixArt-Alpha, for example) 🤗 See instructions below the form on how to pass"
    " `controlnet_id` or `t2i_adapter_id`.",
    article=ARTICLE,
    fn=get_component_wise_memory,
    inputs=[
        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
        gr.components.Textbox(lines=1, label="controlnet_id", info="Example: lllyasviel/sd-controlnet-canny"),
        gr.components.Textbox(lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"),
        gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
        gr.components.Radio(
            ALLOWED_VARIANTS,
            value="fp32",  # preselect so submitting without a choice doesn't send None
            label="variant",
            info="Precision to use for calculation.",
        ),
        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
        gr.components.Radio(
            [".bin", ".safetensors"],
            value=".safetensors",  # preselect so submitting without a choice doesn't send None
            label="extension",
            info="Extension to use.",
        ),
    ],
    outputs=[gr.Markdown(label="Output")],
    examples=[
        ["runwayml/stable-diffusion-v1-5", None, None, None, "fp32", None, ".safetensors"],
        ["PixArt-alpha/PixArt-XL-2-1024-MS", None, None, None, "fp32", None, ".safetensors"],
        ["runwayml/stable-diffusion-v1-5", "lllyasviel/sd-controlnet-canny", None, None, "fp32", None, ".safetensors"],
        [
            "stabilityai/stable-diffusion-xl-base-1.0",
            None,
            "TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
            None,
            "fp16",
            None,
            ".safetensors",
        ],
        ["stabilityai/stable-cascade", None, None, None, "bf16", None, ".safetensors"],
        ["Deci/DeciDiffusion-v2-0", None, None, None, "fp32", None, ".safetensors"],
    ],
    theme=gr.themes.Soft(),
    allow_flagging="never",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch(show_error=True)