JVice commited on
Commit
1fdf512
·
verified ·
1 Parent(s): 03b145d

Added LatentConsistencyModel Pipeline functionality

Browse files
Files changed (1) hide show
  1. model_loading.py +6 -2
model_loading.py CHANGED
@@ -7,7 +7,8 @@ if torch.cuda.is_available():
7
  else:
8
  device = 'cpu'
9
 
10
- validT2IModelTypes = ["KandinskyPipeline", "StableDiffusionPipeline", "DiffusionPipeline", "StableDiffusionXLPipeline"]
 
11
  def check_if_model_exists(repoName):
12
  modelLoaded = None
13
  huggingFaceURL = "https://huggingface.co/" + repoName + "/raw/main/model_index.json"
@@ -44,8 +45,11 @@ def import_model(modelID, modelType):
44
  if modelType == 'StableDiffusionXLPipeline':
45
  from diffusers import StableDiffusionXLPipeline
46
  T2IModel = StableDiffusionXLPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
 
 
 
47
  else:
48
  from diffusers import AutoPipelineForText2Image
49
  T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
50
- T2IModel.to(device)
51
  return T2IModel
 
7
  else:
8
  device = 'cpu'
9
 
10
+ validT2IModelTypes = ["KandinskyPipeline", "StableDiffusionPipeline", "DiffusionPipeline", "StableDiffusionXLPipeline",
11
+ "LatentConsistencyModelPipeline"]
12
  def check_if_model_exists(repoName):
13
  modelLoaded = None
14
  huggingFaceURL = "https://huggingface.co/" + repoName + "/raw/main/model_index.json"
 
45
  if modelType == 'StableDiffusionXLPipeline':
46
  from diffusers import StableDiffusionXLPipeline
47
  T2IModel = StableDiffusionXLPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
48
+ elif modelType == 'LatentConsistencyModelPipeline':
49
+ from diffusers import DiffusionPipeline
50
+ T2IModel = DiffusionPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
51
  else:
52
  from diffusers import AutoPipelineForText2Image
53
  T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
54
+ T2IModel.to("cuda")
55
  return T2IModel