Commit 41dfa78 by sayakpaul (parent: c2058c7)

beautified the app

Files changed (1): app.py (+78 -49)

app.py CHANGED
@@ -1,4 +1,4 @@
-from huggingface_hub import model_info, hf_hub_download
+from huggingface_hub import hf_hub_download, model_info
 import gradio as gr
 import json
 
@@ -54,11 +54,11 @@ def format_output(pipeline_id, memory_mapping, variant=None, controlnet_mapping=
     for component, memory in memory_mapping.items():
         markdown_str += f"* {component}: {format_size(memory)}\n"
     if controlnet_mapping:
-        markdown_str += "\n## ControlNet(s)\n"
+        markdown_str += f"\n## ControlNet(s) ({variant})\n"
         for controlnet_id, memory in controlnet_mapping.items():
             markdown_str += f"* {controlnet_id}: {format_size(memory)}\n"
     if t2i_adapter_mapping:
-        markdown_str += "\n## T2I-Adapters(s)\n"
+        markdown_str += f"\n## T2I-Adapter(s) ({variant})\n"
         for t2_adapter_id, memory in t2i_adapter_mapping.items():
             markdown_str += f"* {t2_adapter_id}: {format_size(memory)}\n"
 
@@ -202,49 +202,78 @@ def get_component_wise_memory(
     return format_output(pipeline_id, component_wise_memory, variant, controlnet_mapping, t2_adapter_mapping)
 
 
-with gr.Interface(
-    title="Compute component-wise memory of a 🧨 Diffusers pipeline.",
-    description="Pipelines containing text encoders with sharded checkpoints are also supported"
-    " (PixArt-Alpha, for example) 🤗 See instructions below the form on how to pass"
-    " `controlnet_id` or `t2_adapter_id`.",
-    article=ARTICLE,
-    fn=get_component_wise_memory,
-    inputs=[
-        gr.components.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5"),
-        gr.components.Textbox(lines=1, label="controlnet_id", info="Example: lllyasviel/sd-controlnet-canny"),
-        gr.components.Textbox(lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"),
-        gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private/gated repositories."),
-        gr.components.Radio(
-            ALLOWED_VARIANTS,
-            label="variant",
-            info="Precision to use for calculation.",
-        ),
-        gr.components.Textbox(lines=1, label="revision", info="Repository revision to use."),
-        gr.components.Radio(
-            [".bin", ".safetensors"],
-            label="extension",
-            info="Extension to use.",
-        ),
-    ],
-    outputs=[gr.Markdown(label="Output")],
-    examples=[
-        ["runwayml/stable-diffusion-v1-5", None, None, None, "fp32", None, ".safetensors"],
-        ["PixArt-alpha/PixArt-XL-2-1024-MS", None, None, None, "fp32", None, ".safetensors"],
-        ["runwayml/stable-diffusion-v1-5", "lllyasviel/sd-controlnet-canny", None, None, "fp32", None, ".safetensors"],
-        [
-            "stabilityai/stable-diffusion-xl-base-1.0",
-            None,
-            "TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
-            None,
-            "fp16",
-            None,
-            ".safetensors",
-        ],
-        ["stabilityai/stable-cascade", None, None, None, "bf16", None, ".safetensors"],
-        ["Deci/DeciDiffusion-v2-0", None, None, None, "fp32", None, ".safetensors"],
-    ],
-    theme=gr.themes.Soft(),
-    allow_flagging="never",
-    cache_examples=False,
-) as demo:
-    demo.launch(show_error=True)
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    with gr.Column():
+        gr.Markdown(
+            """<img src="https://huggingface.co/spaces/hf-accelerate/model-memory-usage/resolve/main/measure_model_size.png" style="float: left;" width="150" height="175"><h1>🧨 Diffusers Pipeline Memory Calculator</h1>
+            This tool will help you gauge the memory requirements of a Diffusers pipeline. Pipelines containing text encoders with sharded checkpoints are also supported
+            (PixArt-Alpha, for example) 🤗 See instructions below the form on how to pass `controlnet_id` or `t2_adapter_id`. When performing inference, expect to add up to an
+            additional 20% on top of this, as found by [EleutherAI](https://blog.eleuther.ai/transformer-math/).
+            """
+        )
+        out_text = gr.Markdown()
+        with gr.Row():
+            pipeline_id = gr.Textbox(lines=1, label="pipeline_id", info="Example: runwayml/stable-diffusion-v1-5")
+        with gr.Row():
+            controlnet_id = gr.Textbox(lines=1, label="controlnet_id", info="Example: lllyasviel/sd-controlnet-canny")
+            t2i_adapter_id = gr.Textbox(
+                lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"
+            )
+        with gr.Row():
+            token = gr.Textbox(lines=1, label="hf_token", info="Pass this in case of private/gated repositories.")
+            variant = gr.Radio(
+                ALLOWED_VARIANTS,
+                label="variant",
+                info="Precision to use for calculation.",
+            )
+            revision = gr.Textbox(lines=1, label="revision", info="Repository revision to use.")
+            extension = gr.Radio(
+                [".bin", ".safetensors"],
+                label="extension",
+                info="Extension to use.",
+            )
+        with gr.Row():
+            btn = gr.Button("Calculate Memory Usage")
+
+        gr.Markdown("## Examples")
+        gr.Examples(
+            [
+                ["runwayml/stable-diffusion-v1-5", None, None, None, "fp32", None, ".safetensors"],
+                ["PixArt-alpha/PixArt-XL-2-1024-MS", None, None, None, "fp32", None, ".safetensors"],
+                [
+                    "runwayml/stable-diffusion-v1-5",
+                    "lllyasviel/sd-controlnet-canny",
+                    None,
+                    None,
+                    "fp32",
+                    None,
+                    ".safetensors",
+                ],
+                [
+                    "stabilityai/stable-diffusion-xl-base-1.0",
+                    None,
+                    "TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
+                    None,
+                    "fp16",
+                    None,
+                    ".safetensors",
+                ],
+                ["stabilityai/stable-cascade", None, None, None, "bf16", None, ".safetensors"],
+                ["Deci/DeciDiffusion-v2-0", None, None, None, "fp32", None, ".safetensors"],
+            ],
+            [pipeline_id, controlnet_id, t2i_adapter_id, token, variant, revision, extension],
+            out_text,
+            get_component_wise_memory,
+            cache_examples=False,
+        )
+
+        gr.Markdown(ARTICLE)
+
+        btn.click(
+            get_component_wise_memory,
+            inputs=[pipeline_id, controlnet_id, t2i_adapter_id, token, variant, revision, extension],
+            outputs=[out_text],
+            api_name=False,
+        )
+
+demo.launch(show_error=True)
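
The app description above cites the EleutherAI rule of thumb that inference typically needs up to roughly 20% more memory than the checkpoint weights alone. Below is a minimal, self-contained sketch of how that heuristic could be applied to the per-component sizes this app reports. It is an illustration only: `estimate_inference_memory`, the local `format_size` helper, and the `component_sizes_bytes` numbers are hypothetical and not part of the app's code.

```python
# Hypothetical illustration of the ~20% inference-overhead rule of thumb
# referenced in the app description (https://blog.eleuther.ai/transformer-math/).
# Names and numbers here are made up for demonstration; the real app computes
# component sizes from the Hub repository's checkpoint files.

def format_size(num_bytes: float) -> str:
    """Render a byte count in a human-readable unit."""
    for unit in ("B", "KB", "MB", "GB"):
        if num_bytes < 1024.0:
            return f"{num_bytes:.2f} {unit}"
        num_bytes /= 1024.0
    return f"{num_bytes:.2f} TB"


def estimate_inference_memory(component_sizes_bytes: dict, overhead: float = 0.20) -> str:
    """Sum component checkpoint sizes and add a flat overhead margin for inference."""
    total = sum(component_sizes_bytes.values())
    return format_size(total * (1.0 + overhead))


# Made-up numbers for a hypothetical fp16 pipeline.
component_sizes_bytes = {
    "text_encoder": 246_000_000,
    "unet": 1_700_000_000,
    "vae": 167_000_000,
}
print(estimate_inference_memory(component_sizes_bytes))  # ~2.36 GB
```

In other words, the figures the calculator reports are weight sizes only; budgeting an extra ~20% on top of the total is a reasonable first estimate for actual inference memory.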