import os

import torch

import modules.core as core
from modules.path import modelfile_path, lorafile_path


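# Checkpoint and LoRA locations: SDXL base, SDXL refiner, and the offset-noise example LoRA.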
xl_base_filename = os.path.join(modelfile_path, 'sd_xl_base_1.0.safetensors')
xl_refiner_filename = os.path.join(modelfile_path, 'sd_xl_refiner_1.0.safetensors')
xl_base_offset_lora_filename = os.path.join(lorafile_path, 'sd_xl_offset_example-lora_1.0.safetensors')

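# Load the base checkpoint and merge the offset LoRA into its UNet weights only
# (strength_clip=0.0 leaves the text encoder untouched). The base VAE is dropped
# because decoding in process() uses the refiner's VAE instead.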
xl_base = core.load_model(xl_base_filename)
xl_base = core.load_lora(xl_base, xl_base_offset_lora_filename, strength_model=0.5, strength_clip=0.0)
del xl_base.vae

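# Load the refiner checkpoint; its UNet finishes sampling and its VAE decodes the result.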
xl_refiner = core.load_model(xl_refiner_filename)


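# Full text-to-image pass: prompt encoding, base/refiner sampling, VAE decode.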
@torch.no_grad()
def process(positive_prompt, negative_prompt, steps, switch, width, height, image_seed, callback):
    positive_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=positive_prompt)
    negative_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=negative_prompt)

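    # Encode them again for the refiner, which has its own text encoder.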
    positive_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=positive_prompt)
    negative_conditions_refiner = core.encode_prompt_condition(clip=xl_refiner.clip, prompt=negative_prompt)

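    # Start from an empty latent at the requested resolution (batch of one).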
    empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1)

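    # Sample with the base UNet, handing off to the refiner UNet at step `switch`.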
    sampled_latent = core.ksampler_with_refiner(
        model=xl_base.unet,
        positive=positive_conditions,
        negative=negative_conditions,
        refiner=xl_refiner.unet,
        refiner_positive=positive_conditions_refiner,
        refiner_negative=negative_conditions_refiner,
        refiner_switch_step=switch,
        latent=empty_latent,
        steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
        seed=image_seed,
        callback_function=callback
    )

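    # Decode the sampled latent with the refiner's VAE and convert to numpy images.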
    decoded_latent = core.decode_vae(vae=xl_refiner.vae, latent_image=sampled_latent)

    images = core.image_to_numpy(decoded_latent)

    return images