|
from huggingface_hub import model_info, hf_hub_download |
|
import gradio as gr |
|
import json |
|
|
|
# Names left out of the size report. They are used two ways further down:
# as exact keys to skip in ``model_index.json`` and as substrings to skip in
# repository file paths. Kept as a list because callers copy it with ``.copy()``.
component_filter = [
    "scheduler",
    "safety_checker",
    "tokenizer",
]
|
|
|
|
|
def format_size(num: int) -> str:
    """Render a byte count as a short string with a decimal unit prefix.

    The value is divided by 1000 until its magnitude drops below 1000 or the
    prefixes up to "Z" are exhausted, in which case "Y" is used. One decimal
    place is always shown and no trailing "B" is appended (``1500`` -> ``"1.5K"``).

    Adapted from https://stackoverflow.com/a/1094933
    """
    prefixes = ("", "K", "M", "G", "T", "P", "E", "Z")
    value = float(num)
    idx = 0
    # ``not (... < 1000.0)`` rather than ``>= 1000.0`` so that NaN keeps
    # falling through every prefix, exactly like a ``<`` early-return would.
    while idx < len(prefixes) and not abs(value) < 1000.0:
        value /= 1000.0
        idx += 1

    if idx < len(prefixes):
        return f"{value:3.1f}{prefixes[idx]}"
    return f"{value:.1f}Y"
|
|
|
|
|
def format_output(pipeline_id, memory_mapping):
    """Build the Markdown report for one pipeline.

    Args:
        pipeline_id: Text placed in the ``##`` heading.
        memory_mapping: Mapping of component name to size in bytes. May be
            empty or ``None``, in which case only the heading is returned.

    Returns:
        A Markdown string: a heading line followed by one bullet per component,
        each size rendered with ``format_size``.
    """
    lines = [f"## {pipeline_id}\n"]
    if memory_mapping:
        lines.extend(
            f"* {component}: {format_size(memory)}\n"
            for component, memory in memory_mapping.items()
        )
    return "".join(lines)
|
|
|
|
|
def load_model_index(pipeline_id, token=None, revision=None):
    """Download and parse ``model_index.json`` from a Hub repository.

    Args:
        pipeline_id: Repository id passed to ``hf_hub_download`` as ``repo_id``.
        token: Optional access token forwarded to ``hf_hub_download``.
        revision: Optional repository revision forwarded to ``hf_hub_download``.

    Returns:
        The decoded JSON content of ``model_index.json``.
    """
    index_path = hf_hub_download(
        repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token
    )
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (the default for text-mode ``open``).
    with open(index_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
def _has_requested_checkpoint(filenames, extension, variant):
    """Return True if some name contains *extension* and, when given, *variant*."""
    if variant is None:
        return any(extension in name for name in filenames)
    return any(extension in name and variant in name for name in filenames)


def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
    """Report the checkpoint size of each component of a pipeline repository.

    The repository file listing (with sizes) and ``model_index.json`` are
    fetched from the Hub; the checkpoint file matching ``extension`` and
    ``variant`` is picked for every component folder, and the sizes are
    rendered as Markdown.

    Args:
        pipeline_id: Hub repository id of the pipeline.
        token: Access token; an empty string is treated as ``None``.
        variant: Checkpoint variant substring to look for in file names.
            ``"fp32"`` is treated as "no variant" (``None``).
        revision: Repository revision; an empty string is treated as ``None``.
        extension: Checkpoint file extension. A falsy value falls back to
            ``".safetensors"``.

    Returns:
        The Markdown string produced by ``format_output``.

    Raises:
        ValueError: If a non-filtered key of ``model_index.json`` matches no
            file in the repository, or if a component that ships weight files
            has none matching the requested extension/variant.
    """
    # Normalise "empty" UI values to the defaults the Hub helpers expect.
    if token == "":
        token = None
    if revision == "":
        revision = None
    if variant == "fp32":
        variant = None
    if not extension:
        # An unset dropdown would otherwise reach the substring checks below
        # as a non-string and raise TypeError.
        extension = ".safetensors"

    print(f"pipeline_id: {pipeline_id}, variant: {variant}, revision: {revision}, extension: {extension}")

    files_in_repo = model_info(pipeline_id, revision=revision, token=token, files_metadata=True).siblings
    index_dict = load_model_index(pipeline_id, token=token, revision=revision)

    # --- Validate that each component offers the requested checkpoint. ---
    index_filter = [*component_filter, "_class_name", "_diffusers_version"]
    # Suffixes treated as "this component ships weights".
    # NOTE(review): limited to the formats the UI offers plus the requested one;
    # extend if other checkpoint formats must be validated.
    weight_suffixes = (".bin", ".safetensors", extension)
    for current_component in index_dict:
        if current_component in index_filter:
            continue

        component_filenames = [
            fileobj.rfilename for fileobj in files_in_repo if current_component in fileobj.rfilename
        ]
        if not component_filenames:
            raise ValueError(f"Problem with {current_component}.")

        # Config-only components (no weight file at all) contribute no size,
        # so there is nothing to validate for them.
        if not any(name.endswith(weight_suffixes) for name in component_filenames):
            continue

        # The previous inline lambda swallowed the whole conditional expression,
        # so with no variant it returned a (truthy) lambda and never checked
        # the extension. The helper applies the intended predicate.
        if not _has_requested_checkpoint(component_filenames, extension, variant):
            raise ValueError(
                f"Requested extension ({extension}) and variant ({variant}) not present for {current_component}. Available files for this component:\n{component_filenames}."
            )

    # --- Sharded checkpoints: sum the text-encoder shards. ---
    # Any ``.index.json`` in the repository switches this path on.
    is_text_encoder_sharded = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
    component_wise_memory = {}
    if is_text_encoder_sharded:
        for current_file in files_in_repo:
            name = current_file.rfilename
            if "text_encoder" not in name:
                continue
            if name.endswith(".json") or not name.endswith(extension):
                continue
            # Honour the requested variant; previously both branches selected
            # the file, so shards of every precision were added together.
            if variant is not None and variant not in name:
                continue
            component_wise_memory["text_encoder"] = (
                component_wise_memory.get("text_encoder", 0) + current_file.size
            )

    # --- Remaining components: one checkpoint file per top-level folder. ---
    # Work on a per-call list: appending to the module-level ``component_filter``
    # would hide the text encoder of every later request.
    excluded_substrings = list(component_filter)
    if is_text_encoder_sharded:
        excluded_substrings.append("text_encoder")

    for current_file in files_in_repo:
        name = current_file.rfilename
        if any(substring in name for substring in excluded_substrings):
            continue

        path_parts = name.split("/")
        # Only files directly inside a folder named after an index entry.
        if len(path_parts) != 2 or path_parts[0] not in index_dict:
            continue
        if name.endswith(".json") or not name.endswith(extension):
            continue
        # Any name containing "ema" is skipped (this also drops "non_ema").
        if "ema" in name:
            continue
        if variant is not None and variant not in name:
            continue

        print(name)
        # A later matching file overwrites an earlier one for the same component.
        component_wise_memory[path_parts[0]] = current_file.size

    return format_output(pipeline_id, component_wise_memory)
|
|
|
|
|
# Build the Gradio UI and start serving it. The inputs are listed in the same
# order as the positional parameters of ``get_component_wise_memory``:
# pipeline_id, token, variant, revision, extension.
gr.Interface(
    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
    # Sizes come from ``format_size``, which picks the unit prefix per value,
    # so the text no longer claims everything is in GB.
    description="Sizes are reported in bytes with decimal unit prefixes (K, M, G, ...). Pipelines containing text encoders with sharded checkpoints are also supported (PixArt-Alpha, for example) 🤗",
    fn=get_component_wise_memory,
    inputs=[
        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
        gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
        gr.components.Dropdown(
            [
                "fp32",
                "fp16",
            ],
            label="variant",
            info="Precision to use for calculation.",
        ),
        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
        gr.components.Dropdown(
            [".bin", ".safetensors"],
            # Pre-select the function's default so a submission without a
            # manual choice still passes a usable extension.
            value=".safetensors",
            label="extension",
            info="Extension to use.",
        ),
    ],
    outputs=[gr.Markdown(label="Output")],
    examples=[
        ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
        ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
        ["PixArt-alpha/PixArt-XL-2-1024-MS", None, "fp32", None, ".safetensors"],
    ],
    theme=gr.themes.Soft(),
    # String form of the flagging option; the boolean form is deprecated.
    allow_flagging="never",
).launch(show_error=True)
|
|