JVice commited on
Commit
7cd78a5
1 Parent(s): a6670ea

Update model_loading.py

Browse files

Added support for 'StableDiffusion3Pipeline' and 'FluxPipeline' models

Files changed (1) hide show
  1. model_loading.py +7 -1
model_loading.py CHANGED
@@ -8,7 +8,7 @@ else:
8
  device = 'cpu'
9
 
10
  validT2IModelTypes = ["KandinskyPipeline", "StableDiffusionPipeline", "DiffusionPipeline", "StableDiffusionXLPipeline",
11
- "LatentConsistencyModelPipeline"]
12
  def check_if_model_exists(repoName):
13
  modelLoaded = None
14
  huggingFaceURL = "https://huggingface.co/" + repoName + "/raw/main/model_index.json"
@@ -40,6 +40,12 @@ def import_model(modelID, modelType):
40
  elif modelType == 'LatentConsistencyModelPipeline':
41
  from diffusers import DiffusionPipeline
42
  T2IModel = DiffusionPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
 
 
 
 
 
 
43
  else:
44
  from diffusers import AutoPipelineForText2Image
45
  T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
 
8
  device = 'cpu'
9
 
10
  validT2IModelTypes = ["KandinskyPipeline", "StableDiffusionPipeline", "DiffusionPipeline", "StableDiffusionXLPipeline",
11
+ "LatentConsistencyModelPipeline","StableDiffusion3Pipeline", "FluxPipeline"]
12
  def check_if_model_exists(repoName):
13
  modelLoaded = None
14
  huggingFaceURL = "https://huggingface.co/" + repoName + "/raw/main/model_index.json"
 
40
  elif modelType == 'LatentConsistencyModelPipeline':
41
  from diffusers import DiffusionPipeline
42
  T2IModel = DiffusionPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
43
+ elif modelType == 'StableDiffusion3Pipeline':
44
+ from diffusers import StableDiffusion3Pipeline
45
+ T2IModel = StableDiffusion3Pipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
46
+ elif modelType == 'FluxPipeline':
47
+ from diffusers import FluxPipeline
48
+ T2IModel = FluxPipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
49
  else:
50
  from diffusers import AutoPipelineForText2Image
51
  T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)