JVice commited on
Commit
485a2af
·
verified ·
1 Parent(s): 70419e4

Update model_loading.py

Browse files
Files changed (1) hide show
  1. model_loading.py +11 -0
model_loading.py CHANGED
@@ -37,12 +37,15 @@ def import_model(modelID, modelType):
37
  if modelType == 'StableDiffusionXLPipeline':
38
  from diffusers import StableDiffusionXLPipeline
39
  T2IModel = StableDiffusionXLPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
 
40
  elif modelType == 'LatentConsistencyModelPipeline':
41
  from diffusers import DiffusionPipeline
42
  T2IModel = DiffusionPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
 
43
  elif modelType == 'StableDiffusion3Pipeline':
44
  from diffusers import StableDiffusion3Pipeline
45
  T2IModel = StableDiffusion3Pipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
 
46
  elif modelType == 'FluxPipeline':
47
  from diffusers import FluxPipeline
48
  T2IModel = FluxPipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
@@ -50,6 +53,14 @@ def import_model(modelID, modelType):
50
  else:
51
  from diffusers import AutoPipelineForText2Image
52
  T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
 
53
 
 
 
 
54
  T2IModel.to("cuda")
 
 
 
 
55
  return T2IModel
 
37
  if modelType == 'StableDiffusionXLPipeline':
38
  from diffusers import StableDiffusionXLPipeline
39
  T2IModel = StableDiffusionXLPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
40
+ T2IModel.to("cuda")
41
  elif modelType == 'LatentConsistencyModelPipeline':
42
  from diffusers import DiffusionPipeline
43
  T2IModel = DiffusionPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
44
+ T2IModel.to("cuda")
45
  elif modelType == 'StableDiffusion3Pipeline':
46
  from diffusers import StableDiffusion3Pipeline
47
  T2IModel = StableDiffusion3Pipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
48
+ T2IModel.to("cuda")
49
  elif modelType == 'FluxPipeline':
50
  from diffusers import FluxPipeline
51
  T2IModel = FluxPipeline.from_pretrained(modelID, torch_dtype=torch.bfloat16)
 
53
  else:
54
  from diffusers import AutoPipelineForText2Image
55
  T2IModel = AutoPipelineForText2Image.from_pretrained(modelID, torch_dtype=torch.float16)
56
+ T2IModel.to("cuda")
57
 
58
+ if 'StableDiffusionXLPipeline' in modelType.split(','):
59
+ from diffusers import StableDiffusionXLPipeline
60
+ T2IModel = StableDiffusionXLPipeline.from_pretrained(modelID, torch_dtype=torch.float16)
61
  T2IModel.to("cuda")
62
+ try:
63
+ T2IModel.safety_checker = None
64
+ except:
65
+ pass # if the model does not contain a safety checker no need to remove it
66
  return T2IModel