import json

import gradio as gr
from huggingface_hub import hf_hub_download, model_info

# Names that are skipped when validating `model_index.json` entries, and that are
# matched as substrings of file paths to skip files when summing checkpoint sizes.
COMPONENT_FILTER = [
    "scheduler",
    "feature_extractor",
    "tokenizer",
    "tokenizer_2",
    "_class_name",
    "_diffusers_version",
]

ARTICLE = """
## Notes on how to use the `controlnet_id` and `t2i_adapter_id` fields

Both `controlnet_id` and `t2i_adapter_id` fields support passing multiple checkpoint ids,
e.g., "thibaud/controlnet-openpose-sdxl-1.0,diffusers/controlnet-canny-sdxl-1.0". For
`t2i_adapter_id`, this could be like - "TencentARC/t2iadapter_keypose_sd14v1,TencentARC/t2iadapter_depth_sd14v1".

Users should take care of passing the underlying base `pipeline_id` appropriately. For example,
passing `pipeline_id` as "runwayml/stable-diffusion-v1-5" and `controlnet_id` as "thibaud/controlnet-openpose-sdxl-1.0"
won't result in an error but these two things aren't meant to be compatible. You should pass
a `controlnet_id` that is compatible with "runwayml/stable-diffusion-v1-5".

For further clarification on this topic, feel free to open a [discussion](https://huggingface.co/spaces/diffusers/compute-pipeline-size/discussions).

📔 Also, note that the `revision` field is only reserved for `pipeline_id`. It won't have any effect on the
`controlnet_id` or `t2i_adapter_id`.
"""

# Precision choices offered in the UI. "fp32" is mapped to `variant=None` below, i.e. it
# selects files that contain none of the other variant names.
ALLOWED_VARIANTS = ["fp32", "fp16", "bf16"]


def format_size(num: int) -> str:
    """Format size in bytes into a human-readable string.

    Uses decimal (powers of 1000) units. Taken from https://stackoverflow.com/a/1094933
    """
    num_f = float(num)
    for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]:
        if abs(num_f) < 1000.0:
            return f"{num_f:3.1f}{unit}"
        num_f /= 1000.0
    return f"{num_f:.1f}Y"


def _format_mapping(mapping):
    """Render a ``{name: size_in_bytes}`` mapping as Markdown bullet lines."""
    return "".join(f"* {name}: {format_size(size)}\n" for name, size in mapping.items())


def format_output(pipeline_id, memory_mapping, controlnet_mapping=None, t2i_adapter_mapping=None):
    """Build the Markdown report shown in the UI.

    Args:
        pipeline_id: Repo id used as the top-level heading.
        memory_mapping: ``{component: size_in_bytes}`` for the pipeline components.
        controlnet_mapping: Optional ``{controlnet_id: size_in_bytes}``.
        t2i_adapter_mapping: Optional ``{t2i_adapter_id: size_in_bytes}``.

    Returns:
        The Markdown string; empty/None mappings produce no section.
    """
    markdown_str = f"## {pipeline_id}\n"
    if memory_mapping:
        markdown_str += _format_mapping(memory_mapping)
    if controlnet_mapping:
        markdown_str += "\n## ControlNet(s)\n" + _format_mapping(controlnet_mapping)
    if t2i_adapter_mapping:
        markdown_str += "\n## T2I-Adapter(s)\n" + _format_mapping(t2i_adapter_mapping)
    return markdown_str


def load_model_index(pipeline_id, token=None, revision=None):
    """Download and parse ``model_index.json`` of ``pipeline_id`` at ``revision``."""
    index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
    with open(index_path, "r", encoding="utf-8") as f:
        index_dict = json.load(f)
    return index_dict


def _matches_variant(filename, variant):
    """Return True if ``filename`` belongs to the requested precision ``variant``.

    With ``variant=None`` (full precision), files naming any other allowed variant
    (e.g. ``*.fp16.safetensors``) are rejected so they are not mistaken for the
    default checkpoint.
    """
    if variant is not None:
        return variant in filename
    return all(other not in filename for other in ALLOWED_VARIANTS[1:])


def _is_checkpoint_file(filename, variant, extension):
    """Return True if ``filename`` is a weight file with the requested extension and variant.

    JSON files (configs, shard indices such as ``*.safetensors.index.json``) and
    EMA/non-EMA copies are excluded so only the main weights are counted.
    """
    if filename.endswith(".json") or not filename.endswith(extension):
        return False
    if "ema" in filename:
        return False
    return _matches_variant(filename, variant)


def get_individual_model_memory(id, token, variant, extension):  # noqa: A002 - public parameter name
    """Return the size in bytes of the first matching checkpoint file in Hub repo ``id``.

    Used for standalone ControlNet / T2I-Adapter repos.

    Raises:
        ValueError: if the repo has no file with the requested ``extension`` and ``variant``.
    """
    files_in_repo = model_info(id, token=token, files_metadata=True).siblings or []
    candidates = [
        file_obj
        for file_obj in files_in_repo
        if file_obj.rfilename.endswith(extension) and _matches_variant(file_obj.rfilename, variant)
    ]
    if not candidates:
        available = ", ".join(file_obj.rfilename for file_obj in files_in_repo)
        raise ValueError(
            f"Requested extension ({extension}) and variant ({variant}) not present for {id}."
            f" Available files: {available}."
        )
    return candidates[0].size


def _parse_ids(ids):
    """Split a comma-separated string of repo ids, stripping whitespace and dropping blanks."""
    return [repo_id.strip() for repo_id in ids.split(",") if repo_id.strip()]


def _individual_memory_mapping(ids, token, variant, extension):
    """Map each repo id in the comma-separated ``ids`` string to its checkpoint size."""
    return {
        repo_id: get_individual_model_memory(repo_id, token=token, variant=variant, extension=extension)
        for repo_id in _parse_ids(ids)
    }


def _validate_components(index_dict, files_in_repo, variant, extension):
    """Raise ValueError if a model component lacks a checkpoint with the requested variant/extension."""
    for component, spec in index_dict.items():
        # Loadable components are `[library, class_name]` pairs. `[null, null]` marks an
        # optional component that this pipeline does not ship, so there is nothing to check.
        if component in COMPONENT_FILTER or not isinstance(spec, list) or len(spec) != 2 or spec[0] is None:
            continue
        # Match the component folder exactly so "text_encoder" does not match "text_encoder_2/".
        filenames = [f.rfilename for f in files_in_repo if f.rfilename.startswith(f"{component}/")]
        if not filenames:
            raise ValueError(f"Problem with {component}.")
        if not any(_is_checkpoint_file(name, variant, extension) for name in filenames):
            formatted_filenames = ", ".join(filenames)
            raise ValueError(
                f"Requested extension ({extension}) and variant ({variant}) not present for {component}."
                f" Available files for this component: {formatted_filenames}."
            )


def _sum_component_sizes(index_dict, files_in_repo, variant, extension):
    """Return ``{component: total_bytes}`` over the selected checkpoint files of each component.

    Sizes are summed per component folder, so sharded checkpoints (e.g. a text encoder
    split into several ``model-0000x-of-0000y.safetensors`` files) are counted in full.
    """
    component_wise_memory = {}
    for file_obj in files_in_repo:
        filename = file_obj.rfilename
        if any(skip in filename for skip in COMPONENT_FILTER):
            continue
        parts = filename.split("/")
        # Only files directly inside a top-level folder listed in model_index.json count.
        if len(parts) != 2 or parts[0] not in index_dict:
            continue
        if _is_checkpoint_file(filename, variant, extension):
            component = parts[0]
            component_wise_memory[component] = component_wise_memory.get(component, 0) + file_obj.size
    return component_wise_memory


def get_component_wise_memory(
    pipeline_id,
    controlnet_id=None,
    t2i_adapter_id=None,
    token=None,
    variant=None,
    revision=None,
    extension=".safetensors",
):
    """Compute the checkpoint size of every component of a Diffusers pipeline.

    Args:
        pipeline_id: Hub repo id of the pipeline (must contain ``model_index.json``).
        controlnet_id: Optional comma-separated ControlNet repo id(s).
        t2i_adapter_id: Optional comma-separated T2I-Adapter repo id(s).
        token: Optional Hub token for private repositories.
        variant: One of ``ALLOWED_VARIANTS``; ``"fp32"`` selects files without a variant name.
        revision: Optional revision of ``pipeline_id`` (not applied to ControlNets/adapters).
        extension: Checkpoint file extension, e.g. ``".safetensors"`` or ``".bin"``.

    Returns:
        The Markdown report produced by ``format_output``.

    Raises:
        ValueError: if ``pipeline_id`` is empty, if both ``controlnet_id`` and
            ``t2i_adapter_id`` are given, or if a component has no checkpoint with the
            requested ``extension``/``variant``.
    """
    # Gradio passes "" for empty textboxes; treat those as "not provided".
    controlnet_id = controlnet_id or None
    t2i_adapter_id = t2i_adapter_id or None
    token = token or None
    revision = revision or None

    if not pipeline_id or not pipeline_id.strip():
        raise ValueError("`pipeline_id` must be provided.")
    pipeline_id = pipeline_id.strip()

    if controlnet_id and t2i_adapter_id:
        raise ValueError("Both `controlnet_id` and `t2i_adapter_id` cannot be provided.")
    if variant == "fp32":
        variant = None
    # An unselected Radio yields None; fall back to the default extension.
    extension = extension or ".safetensors"

    # Handle ControlNet and T2I-Adapter.
    controlnet_mapping = t2i_adapter_mapping = None
    if controlnet_id is not None:
        controlnet_mapping = _individual_memory_mapping(controlnet_id, token, variant, extension)
    elif t2i_adapter_id is not None:
        t2i_adapter_mapping = _individual_memory_mapping(t2i_adapter_id, token, variant, extension)

    print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")

    # Load pipeline metadata.
    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings or []
    index_dict = load_model_index(pipeline_id, token=token, revision=revision)
    print(f"Index dict: {index_dict}")

    # Fail early if a component lacks checkpoints in the requested variant/extension.
    _validate_components(index_dict, files_in_repo, variant, extension)

    component_wise_memory = _sum_component_sizes(index_dict, files_in_repo, variant, extension)
    return format_output(pipeline_id, component_wise_memory, controlnet_mapping, t2i_adapter_mapping)


demo = gr.Interface(
    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
    description="Pipelines containing text encoders with sharded checkpoints are also supported"
    " (PixArt-Alpha, for example) 🤗 See instructions below the form on how to pass"
    " `controlnet_id` or `t2i_adapter_id`.",
    article=ARTICLE,
    fn=get_component_wise_memory,
    inputs=[
        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
        gr.components.Textbox(lines=1, label="controlnet_id", info="Example: lllyasviel/sd-controlnet-canny"),
        gr.components.Textbox(lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"),
        gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
        gr.components.Radio(
            ALLOWED_VARIANTS,
            label="variant",
            info="Precision to use for calculation.",
        ),
        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
        gr.components.Radio(
            [".bin", ".safetensors"],
            label="extension",
            info="Extension to use.",
        ),
    ],
    outputs=[gr.Markdown(label="Output")],
    examples=[
        ["runwayml/stable-diffusion-v1-5", None, None, None, "fp32", None, ".safetensors"],
        ["PixArt-alpha/PixArt-XL-2-1024-MS", None, None, None, "fp32", None, ".safetensors"],
        ["runwayml/stable-diffusion-v1-5", "lllyasviel/sd-controlnet-canny", None, None, "fp32", None, ".safetensors"],
        [
            "stabilityai/stable-diffusion-xl-base-1.0",
            None,
            "TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
            None,
            "fp16",
            None,
            ".safetensors",
        ],
        ["stabilityai/stable-cascade", None, None, None, "bf16", None, ".safetensors"],
        ["Deci/DeciDiffusion-v2-0", None, None, None, "fp32", None, ".safetensors"],
    ],
    theme=gr.themes.Soft(),
    allow_flagging="never",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch(show_error=True)