sayakpaul HF staff commited on
Commit
0d1d63a
1 Parent(s): 3f6a1fe

reobustify

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -29,6 +29,8 @@ For further clarification on this topic, feel free to open a [discussion](https:
29
  `controlnet_id` or `t2i_adapter_id`.
30
  """
31
 
 
 
32
 
33
  def format_size(num: int) -> str:
34
  """Format size in bytes into a human-readable string.
@@ -69,13 +71,12 @@ def load_model_index(pipeline_id, token=None, revision=None):
69
 
70
  def get_individual_model_memory(id, token, variant, extension):
71
  files_in_repo = model_info(id, token=token, files_metadata=True).siblings
72
- for x in files_in_repo:
73
- if extension in x.rfilename:
74
- if variant:
75
- if variant in x.rfilename:
76
- return x.size
77
- else:
78
- return x.size
79
 
80
 
81
  def get_component_wise_memory(
@@ -211,7 +212,7 @@ with gr.Interface(
211
  gr.components.Textbox(lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"),
212
  gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
213
  gr.components.Radio(
214
- ["fp32", "fp16", "bf16"],
215
  label="variant",
216
  info="Precision to use for calculation.",
217
  ),
@@ -232,7 +233,7 @@ with gr.Interface(
232
  None,
233
  "TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
234
  None,
235
- "fp32",
236
  None,
237
  ".safetensors",
238
  ],
 
29
  `controlnet_id` or `t2i_adapter_id`.
30
  """
31
 
32
+ ALLOWED_VARIANTS = ["fp32", "fp16", "bf16"]
33
+
34
 
35
  def format_size(num: int) -> str:
36
  """Format size in bytes into a human-readable string.
 
71
 
72
  def get_individual_model_memory(id, token, variant, extension):
73
  files_in_repo = model_info(id, token=token, files_metadata=True).siblings
74
+ candidates = [x for x in files_in_repo if extension in x.rfilename]
75
+ if variant:
76
+ candidate = list(filter(lambda x: variant in x.rfilename, candidates))[0]
77
+ else:
78
+ candidate = list(filter(lambda x: all(var not in x.rfilename for var in ALLOWED_VARIANTS[1:]), candidates))[0]
79
+ return candidate.size
 
80
 
81
 
82
  def get_component_wise_memory(
 
212
  gr.components.Textbox(lines=1, label="t2i_adapter_id", info="Example: TencentARC/t2iadapter_color_sd14v1"),
213
  gr.components.Textbox(lines=1, label="hf_token", info="Pass this in case of private repositories."),
214
  gr.components.Radio(
215
+ ALLOWED_VARIANTS,
216
  label="variant",
217
  info="Precision to use for calculation.",
218
  ),
 
233
  None,
234
  "TencentARC/t2i-adapter-lineart-sdxl-1.0,TencentARC/t2i-adapter-canny-sdxl-1.0",
235
  None,
236
+ "fp16",
237
  None,
238
  ".safetensors",
239
  ],