multimodalart HF staff commited on
Commit
b9ca2c3
1 Parent(s): c64188e

fix dtype in migration

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -194,7 +194,8 @@ You should use {formatted_words} to trigger the image generation.
194
  url: >-
195
  {image}
196
  """
197
-
 
198
  content = f"""---
199
  license: other
200
  license_name: bespoke-lora-trained-license
@@ -234,7 +235,9 @@ Weights for this model are available in Safetensors format.
234
  from diffusers import AutoPipelineForText2Image
235
  import torch
236
 
237
- pipeline = AutoPipelineForText2Image.from_pretrained('{info["baseModel"]}', torch_dtype=torch.float16).to('cuda')
 
 
238
  pipeline.load_lora_weights('{user_repo_id}', weight_name='{downloaded_files["weightName"][0]}')
239
  image = pipeline('{prompt if prompt else (formatted_words if formatted_words else 'Your custom prompt')}').images[0]
240
  ```
 
194
  url: >-
195
  {image}
196
  """
197
+ dtype = "torch.bfloat16" if info["baseModel"] == "black-forest-labs/FLUX.1-dev" or info["baseModel"] == "black-forest-labs/FLUX.1-schnell" else "torch.float16"
198
+
199
  content = f"""---
200
  license: other
201
  license_name: bespoke-lora-trained-license
 
235
  from diffusers import AutoPipelineForText2Image
236
  import torch
237
 
238
+ device = "cuda" if torch.cuda.is_available() else "cpu"
239
+
240
+ pipeline = AutoPipelineForText2Image.from_pretrained('{info["baseModel"]}', torch_dtype={dtype}).to(device)
241
  pipeline.load_lora_weights('{user_repo_id}', weight_name='{downloaded_files["weightName"][0]}')
242
  image = pipeline('{prompt if prompt else (formatted_words if formatted_words else 'Your custom prompt')}').images[0]
243
  ```