Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import json
|
|
4 |
|
5 |
component_filter = ["scheduler", "safety_checker", "tokenizer"]
|
6 |
|
|
|
7 |
def format_size(num: int) -> str:
|
8 |
"""Format size in bytes into a human-readable string.
|
9 |
Taken from https://stackoverflow.com/a/1094933
|
@@ -15,6 +16,7 @@ def format_size(num: int) -> str:
|
|
15 |
num_f /= 1000.0
|
16 |
return f"{num_f:.1f}Y"
|
17 |
|
|
|
18 |
def format_output(pipeline_id, memory_mapping):
|
19 |
markdown_str = f"## {pipeline_id}\n"
|
20 |
if memory_mapping:
|
@@ -22,12 +24,14 @@ def format_output(pipeline_id, memory_mapping):
|
|
22 |
markdown_str += f"* {component}: {format_size(memory)}\n"
|
23 |
return markdown_str
|
24 |
|
|
|
25 |
def load_model_index(pipeline_id, token=None, revision=None):
|
26 |
index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
|
27 |
with open(index_path, "r") as f:
|
28 |
index_dict = json.load(f)
|
29 |
return index_dict
|
30 |
|
|
|
31 |
def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
|
32 |
if token == "":
|
33 |
token = None
|
@@ -48,17 +52,22 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
48 |
index_filter.extend(["_class_name", "_diffusers_version"])
|
49 |
for current_component in index_dict:
|
50 |
if current_component not in index_filter:
|
51 |
-
current_component_fileobjs =
|
52 |
if current_component_fileobjs:
|
53 |
current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
|
54 |
-
condition =
|
|
|
|
|
|
|
|
|
55 |
variant_present_with_extension = any(condition(filename) for filename in current_component_filenames)
|
56 |
if not variant_present_with_extension:
|
57 |
-
raise ValueError(
|
|
|
|
|
58 |
else:
|
59 |
raise ValueError(f"Problem with {current_component}.")
|
60 |
|
61 |
-
|
62 |
# Handle text encoder separately when it's sharded.
|
63 |
is_text_encoder_shared = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
|
64 |
component_wise_memory = {}
|
@@ -99,4 +108,37 @@ def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=No
|
|
99 |
print(selected_file.rfilename)
|
100 |
component_wise_memory[component] = selected_file.size
|
101 |
|
102 |
-
return format_output(pipeline_id, component_wise_memory)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
# Pipeline components excluded from the memory report; these are added to
# `index_filter` in get_component_wise_memory and skipped there (presumably
# because they carry little/no weight data — confirm against full file).
component_filter = ["scheduler", "safety_checker", "tokenizer"]
|
6 |
|
7 |
+
|
8 |
def format_size(num: int) -> str:
|
9 |
"""Format size in bytes into a human-readable string.
|
10 |
Taken from https://stackoverflow.com/a/1094933
|
|
|
16 |
num_f /= 1000.0
|
17 |
return f"{num_f:.1f}Y"
|
18 |
|
19 |
+
|
20 |
def format_output(pipeline_id, memory_mapping):
|
21 |
markdown_str = f"## {pipeline_id}\n"
|
22 |
if memory_mapping:
|
|
|
24 |
markdown_str += f"* {component}: {format_size(memory)}\n"
|
25 |
return markdown_str
|
26 |
|
27 |
+
|
28 |
def load_model_index(pipeline_id, token=None, revision=None):
    """Download and parse a pipeline's ``model_index.json`` from the Hub.

    Args:
        pipeline_id: Hub repo id, e.g. ``"runwayml/stable-diffusion-v1-5"``.
        token: Optional access token for private repositories.
        revision: Optional repo revision (branch, tag, or commit sha).

    Returns:
        dict: the decoded contents of ``model_index.json``.
    """
    index_path = hf_hub_download(repo_id=pipeline_id, filename="model_index.json", revision=revision, token=token)
    # JSON files are UTF-8 by spec; be explicit instead of relying on the
    # platform default encoding.
    with open(index_path, "r", encoding="utf-8") as f:
        index_dict = json.load(f)
    return index_dict
|
33 |
|
34 |
+
|
35 |
def get_component_wise_memory(pipeline_id, token=None, variant=None, revision=None, extension=".safetensors"):
|
36 |
if token == "":
|
37 |
token = None
|
|
|
52 |
index_filter.extend(["_class_name", "_diffusers_version"])
|
53 |
for current_component in index_dict:
|
54 |
if current_component not in index_filter:
|
55 |
+
current_component_fileobjs = list(filter(lambda x: current_component in x.rfilename, files_in_repo))
|
56 |
if current_component_fileobjs:
|
57 |
current_component_filenames = [fileobj.rfilename for fileobj in current_component_fileobjs]
|
58 |
+
condition = (
|
59 |
+
lambda filename: extension in filename and variant in filename
|
60 |
+
if variant is not None
|
61 |
+
else lambda filename: extension in filename
|
62 |
+
)
|
63 |
variant_present_with_extension = any(condition(filename) for filename in current_component_filenames)
|
64 |
if not variant_present_with_extension:
|
65 |
+
raise ValueError(
|
66 |
+
f"Requested extension ({extension}) and variant ({variant}) not present for {current_component}. Available files for this component:\n{current_component_filenames}."
|
67 |
+
)
|
68 |
else:
|
69 |
raise ValueError(f"Problem with {current_component}.")
|
70 |
|
|
|
71 |
# Handle text encoder separately when it's sharded.
|
72 |
is_text_encoder_shared = any(".index.json" in file_obj.rfilename for file_obj in files_in_repo)
|
73 |
component_wise_memory = {}
|
|
|
108 |
print(selected_file.rfilename)
|
109 |
component_wise_memory[component] = selected_file.size
|
110 |
|
111 |
+
return format_output(pipeline_id, component_wise_memory)
|
112 |
+
|
113 |
+
|
114 |
+
# Gradio front-end: collect the pipeline parameters and render the
# component-wise memory report produced by get_component_wise_memory.
pipeline_inputs = [
    gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
    gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
    gr.components.Dropdown(
        ["fp32", "fp16"],
        label="variant",
        info="Precision to use for calculation.",
    ),
    gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
    gr.components.Dropdown(
        [".bin", ".safetensors"],
        label="extension",
        info="Extension to use.",
    ),
]

demo_examples = [
    ["runwayml/stable-diffusion-v1-5", None, "fp32", None, ".safetensors"],
    ["stabilityai/stable-diffusion-xl-base-1.0", None, "fp16", None, ".safetensors"],
    ["PixArt-alpha/PixArt-XL-2-1024-MS", None, "fp32", None, ".safetensors"],
]

demo = gr.Interface(
    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
    description="Sizes will be reported in GB. Pipelines containing text encoders with sharded checkpoints are also supported (PixArt-Alpha, for example) 🤗",
    fn=get_component_wise_memory,
    inputs=pipeline_inputs,
    outputs=[gr.Markdown(label="Output")],
    examples=demo_examples,
    theme=gr.themes.Soft(),
    allow_flagging=False,
)
demo.launch(show_error=True)
|