lllyasviel commited on
Commit
9905ad3
·
1 Parent(s): 244a6d4
Files changed (2) hide show
  1. modules/core.py +11 -1
  2. modules/default_pipeline.py +4 -1
modules/core.py CHANGED
@@ -10,7 +10,7 @@ import numpy as np
10
  import comfy.model_management
11
  import comfy.utils
12
 
13
- from comfy.sd import load_checkpoint_guess_config
14
  from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
15
  from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
16
  from modules.samplers_advanced import KSampler, KSamplerWithRefiner
@@ -39,6 +39,16 @@ def load_model(ckpt_filename):
39
  return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
40
 
41
 
 
 
 
 
 
 
 
 
 
 
42
  @torch.no_grad()
43
  def encode_prompt_condition(clip, prompt):
44
  return opCLIPTextEncode.encode(clip=clip, text=prompt)[0]
 
10
  import comfy.model_management
11
  import comfy.utils
12
 
13
+ from comfy.sd import load_checkpoint_guess_config, load_lora_for_models
14
  from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
15
  from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
16
  from modules.samplers_advanced import KSampler, KSamplerWithRefiner
 
39
  return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
40
 
41
 
42
+ @torch.no_grad()
43
+ def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
44
+ if strength_model == 0 and strength_clip == 0:
45
+ return model
46
+
47
+ lora = comfy.utils.load_torch_file(lora_filename, safe_load=True)
48
+ model.unet, model.clip = comfy.sd.load_lora_for_models(model.unet, model.clip, lora, strength_model, strength_clip)
49
+ return model
50
+
51
+
52
  @torch.no_grad()
53
  def encode_prompt_condition(clip, prompt):
54
  return opCLIPTextEncode.encode(clip=clip, text=prompt)[0]
modules/default_pipeline.py CHANGED
@@ -2,13 +2,16 @@ import modules.core as core
2
  import os
3
  import torch
4
 
5
- from modules.path import modelfile_path
6
 
7
 
8
  xl_base_filename = os.path.join(modelfile_path, 'sd_xl_base_1.0.safetensors')
9
  xl_refiner_filename = os.path.join(modelfile_path, 'sd_xl_refiner_1.0.safetensors')
 
10
 
11
  xl_base = core.load_model(xl_base_filename)
 
 
12
  xl_refiner = core.load_model(xl_refiner_filename)
13
  del xl_base.vae
14
 
 
2
  import os
3
  import torch
4
 
5
+ from modules.path import modelfile_path, lorafile_path
6
 
7
 
8
  xl_base_filename = os.path.join(modelfile_path, 'sd_xl_base_1.0.safetensors')
9
  xl_refiner_filename = os.path.join(modelfile_path, 'sd_xl_refiner_1.0.safetensors')
10
+ xl_base_offset_lora_filename = os.path.join(lorafile_path, 'sd_xl_offset_example-lora_1.0.safetensors')
11
 
12
  xl_base = core.load_model(xl_base_filename)
13
+ xl_base = core.load_lora(xl_base, xl_base_offset_lora_filename, strength_model=0.618, strength_clip=0.0)
14
+
15
  xl_refiner = core.load_model(xl_refiner_filename)
16
  del xl_base.vae
17