Spaces:
Paused
Paused
File size: 5,822 Bytes
2494346 bbedeb2 2f3af37 bbedeb2 2f3af37 bbedeb2 2f3af37 bbedeb2 2f3af37 9905ad3 2f3af37 bbedeb2 59406a8 bbedeb2 13d1420 59406a8 ca685d6 01497eb bbedeb2 2f3af37 59406a8 2f3af37 59406a8 2f3af37 bbedeb2 2d98a52 01497eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import modules.core as core
import os
import torch
import modules.path
from comfy.model_base import SDXL, SDXLRefiner
# Currently loaded SDXL base model and the checkpoint name it was loaded for.
# refresh_base_model() compares xl_base_hash against str(name) to skip reloads.
xl_base: core.StableDiffusionModel = None
xl_base_hash = ''
# Currently loaded SDXL refiner (None when no refiner is active) and its name.
# Maintained by refresh_refiner_model().
xl_refiner: core.StableDiffusionModel = None
xl_refiner_hash = ''
# Base model with the requested LoRAs applied; this is the model process() samples with.
# xl_base_patched_hash is str(loras) of the last applied list; refresh_base_model()
# resets it to '' so the next refresh_loras() call re-applies the LoRAs.
xl_base_patched: core.StableDiffusionModel = None
xl_base_patched_hash = ''
def refresh_base_model(name):
    """Make ``name`` the active SDXL base model.

    The checkpoint is resolved relative to ``modules.path.modelfile_path`` and
    loaded with ``core.load_model``. Nothing happens when ``name`` is already
    recorded in ``xl_base_hash``. After a load, ``xl_base_patched`` is reset to
    the plain base model and ``xl_base_patched_hash`` is cleared, so the next
    ``refresh_loras`` call re-applies the LoRAs.

    If the checkpoint is not an SDXL base model, the default base model
    (``modules.path.default_base_model_name``) is loaded instead, while
    ``xl_base_hash`` still records the requested ``name`` so that asking for
    the same unsupported checkpoint again does not trigger another load.

    Args:
        name: Checkpoint file name inside ``modules.path.modelfile_path``.

    Raises:
        RuntimeError: If the default base model itself is not an SDXL base
            model, so there is nothing left to fall back to.
    """
    global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash

    if xl_base_hash == str(name):
        return

    filename = os.path.join(modules.path.modelfile_path, name)

    # Release the previous model before loading the next one.
    if xl_base is not None:
        xl_base.to_meta()
        xl_base = None

    xl_base = core.load_model(filename)
    if not isinstance(xl_base.unet.model, SDXL):
        print('Model not supported. Fooocus only support SDXL model as the base model.')
        # Release the rejected checkpoint the same way a replaced model is
        # released above, instead of only dropping the reference.
        xl_base.to_meta()
        xl_base = None
        xl_base_hash = ''

        default_name = modules.path.default_base_model_name
        if str(name) == str(default_name):
            # Falling back to the default would request this very checkpoint
            # again and recurse until the interpreter's recursion limit,
            # reloading the file on every level. Fail once, with a clear
            # message. (RecursionError is itself a RuntimeError subclass.)
            raise RuntimeError(
                f'Default base model is not a supported SDXL model: {default_name}'
            )

        refresh_base_model(default_name)
        # Remember the *requested* name so the same unsupported selection is
        # a no-op next time instead of another load-and-reject round trip.
        xl_base_hash = name
        xl_base_patched = xl_base
        xl_base_patched_hash = ''
        return

    xl_base_hash = name
    xl_base_patched = xl_base
    xl_base_patched_hash = ''
    print(f'Base model loaded: {xl_base_hash}')
    return
def refresh_refiner_model(name):
    """Make ``name`` the active SDXL refiner, or unload the refiner.

    Nothing happens when ``name`` is already recorded in ``xl_refiner_hash``.
    The literal string ``'None'`` unloads the current refiner. Any other value
    is resolved relative to ``modules.path.modelfile_path`` and loaded with
    ``core.load_model``; if the checkpoint is not an SDXL refiner it is
    rejected and the module is left without a refiner.

    Every path that discards a refiner first calls ``to_meta()`` on it, the
    same release call used when one refiner replaces another.

    Args:
        name: Checkpoint file name inside ``modules.path.modelfile_path``, or
            ``'None'`` to unload.
    """
    global xl_refiner, xl_refiner_hash

    if xl_refiner_hash == str(name):
        return

    if name == 'None':
        # Release the weights explicitly rather than only dropping the
        # reference, matching the replace path below.
        if xl_refiner is not None:
            xl_refiner.to_meta()
        xl_refiner = None
        xl_refiner_hash = ''
        print('Refiner unloaded.')
        return

    filename = os.path.join(modules.path.modelfile_path, name)

    # Release the previous refiner before loading the next one.
    if xl_refiner is not None:
        xl_refiner.to_meta()
        xl_refiner = None

    xl_refiner = core.load_model(filename)
    if not isinstance(xl_refiner.unet.model, SDXLRefiner):
        print('Model not supported. Fooocus only support SDXL refiner as the refiner.')
        # Release the rejected checkpoint before forgetting it.
        xl_refiner.to_meta()
        xl_refiner = None
        xl_refiner_hash = ''
        print('Refiner unloaded.')
        return

    xl_refiner_hash = name
    print(f'Refiner model loaded: {xl_refiner_hash}')

    # process() decodes latents with the base model's VAE, so the refiner's
    # own VAE is released and detached here.
    xl_refiner.vae.first_stage_model.to('meta')
    xl_refiner.vae = None
    return
def refresh_loras(loras):
    """Rebuild ``xl_base_patched`` by applying ``loras`` on top of ``xl_base``.

    ``loras`` is an iterable of ``(name, weight)`` pairs. Entries named
    ``'None'`` are skipped; every other entry is resolved relative to
    ``modules.path.lorafile_path`` and applied with ``core.load_lora`` using
    ``weight`` for both the model and the CLIP strength. The call is a no-op
    when ``str(loras)`` equals the recorded ``xl_base_patched_hash``.
    """
    global xl_base_patched, xl_base_patched_hash

    requested = str(loras)
    if requested == xl_base_patched_hash:
        return

    # Always start from the unpatched base so LoRAs never stack across calls.
    patched = xl_base
    for lora_name, lora_weight in loras:
        if lora_name != 'None':
            lora_path = os.path.join(modules.path.lorafile_path, lora_name)
            patched = core.load_lora(
                patched,
                lora_path,
                strength_model=lora_weight,
                strength_clip=lora_weight,
            )

    xl_base_patched, xl_base_patched_hash = patched, requested
    print(f'LoRAs loaded: {xl_base_patched_hash}')
# Import-time initialisation: load the default base model, the default refiner
# and the default LoRA set. Order matters: refresh_loras() patches whatever
# refresh_base_model() stored in xl_base.
refresh_base_model(modules.path.default_base_model_name)
refresh_refiner_model(modules.path.default_refiner_model_name)
# Five (name, weight) slots; entries named 'None' are skipped by refresh_loras().
refresh_loras([(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)])
# Prompt-conditioning caches reused across process() calls. None means "not
# encoded yet". process() fills them and reuses any non-None value without
# comparing prompts, so clean_prompt_cond_caches() must be called to force
# re-encoding.
positive_conditions_cache = None
negative_conditions_cache = None
# Same caches for the refiner's CLIP; only filled while a refiner is loaded.
positive_conditions_refiner_cache = None
negative_conditions_refiner_cache = None
def clean_prompt_cond_caches():
    """Forget all cached prompt conditionings (base and refiner, both polarities).

    After this call the next ``process`` invocation re-encodes its prompts.
    """
    global positive_conditions_cache, negative_conditions_cache
    global positive_conditions_refiner_cache, negative_conditions_refiner_cache

    positive_conditions_cache = negative_conditions_cache = None
    positive_conditions_refiner_cache = negative_conditions_refiner_cache = None
@torch.no_grad()
def process(positive_prompt, negative_prompt, steps, switch, width, height, image_seed, callback):
    """Run one text-to-image generation with the currently loaded models.

    The prompts are encoded with the CLIP of ``xl_base_patched`` (and with the
    refiner's CLIP when a refiner is loaded), an empty latent of the requested
    size is sampled, and the result is decoded with the VAE of
    ``xl_base_patched``.

    Encoded conditionings are kept in the module-level caches. A non-empty
    cache is reused as-is, without comparing it to the prompt arguments, so
    callers must invoke ``clean_prompt_cond_caches`` to force re-encoding.

    Args:
        positive_prompt: Prompt text for the positive conditioning.
        negative_prompt: Prompt text for the negative conditioning.
        steps: Total number of sampling steps (also used as ``last_step``).
        switch: Step passed as ``refiner_switch_step``; only used when a
            refiner is loaded.
        width: Width passed to ``core.generate_empty_latent``.
        height: Height passed to ``core.generate_empty_latent``.
        image_seed: Seed forwarded to the sampler.
        callback: Forwarded to the sampler as ``callback_function``.

    Returns:
        Whatever ``core.image_to_numpy`` produces for the decoded latent.
    """
    global positive_conditions_cache, negative_conditions_cache, \
        positive_conditions_refiner_cache, negative_conditions_refiner_cache

    # Base-model conditionings: encode only on a cache miss.
    if positive_conditions_cache is None:
        base_positive = core.encode_prompt_condition(clip=xl_base_patched.clip, prompt=positive_prompt)
    else:
        base_positive = positive_conditions_cache

    if negative_conditions_cache is None:
        base_negative = core.encode_prompt_condition(clip=xl_base_patched.clip, prompt=negative_prompt)
    else:
        base_negative = negative_conditions_cache

    positive_conditions_cache, negative_conditions_cache = base_positive, base_negative

    blank_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)

    use_refiner = xl_refiner is not None
    if use_refiner:
        # Refiner conditionings come from the refiner's own CLIP.
        if positive_conditions_refiner_cache is None:
            refiner_positive = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=positive_prompt)
        else:
            refiner_positive = positive_conditions_refiner_cache

        if negative_conditions_refiner_cache is None:
            refiner_negative = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=negative_prompt)
        else:
            refiner_negative = negative_conditions_refiner_cache

        positive_conditions_refiner_cache = refiner_positive
        negative_conditions_refiner_cache = refiner_negative

    # Arguments common to both samplers.
    shared_args = dict(
        model=xl_base_patched.unet,
        positive=base_positive,
        negative=base_negative,
        latent=blank_latent,
        steps=steps,
        start_step=0,
        last_step=steps,
        disable_noise=False,
        force_full_denoise=True,
        seed=image_seed,
        callback_function=callback,
    )

    if use_refiner:
        sampled = core.ksampler_with_refiner(
            refiner=xl_refiner.unet,
            refiner_positive=refiner_positive,
            refiner_negative=refiner_negative,
            refiner_switch_step=switch,
            **shared_args,
        )
    else:
        sampled = core.ksampler(**shared_args)

    decoded = core.decode_vae(vae=xl_base_patched.vae, latent_image=sampled)
    return core.image_to_numpy(decoded)
|