import math
from functools import partial

import torch
from diffusers import StableDiffusionXLKDiffusionPipeline, UNet2DConditionModel
from k_diffusion.sampling import get_sigmas_polyexponential, sample_dpmpp_2m_sde

torch.set_float32_matmul_precision("medium")
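

# Scheduler patch: rebuild the sigma schedule with k-diffusion's
# polyexponential ramp (a polynomial in log-sigma; rho=1 would reproduce the
# plain exponential schedule) and pin the final sigma to exactly zero.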
def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device=None):
    self.num_inference_steps = num_inference_steps
    self.sigmas = get_sigmas_polyexponential(
        num_inference_steps + 1,
        sigma_min=orig_sigmas[-2],  # smallest nonzero sigma of the original schedule
        sigma_max=orig_sigmas[0],
        rho=0.666666,
        device=device or "cpu",
    )
    # Drop the two smallest sigmas and terminate the schedule at exactly zero.
    self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas.new_zeros([1])])
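

# Forward wrapper: run the denoiser under fp16 autocast but hand fp32 back to
# the sampler, so the ODE/SDE arithmetic stays in full precision.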
def model_forward(k_diffusion_model: torch.nn.Module):
    orig_forward = k_diffusion_model.forward

    def forward(*args, **kwargs):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            result = orig_forward(*args, **kwargs)
        return result.float()

    return forward
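

# Pipeline setup: load SDXL weights in fp16, install the polyexponential
# schedule defined above, and sample with DPM++ 2M SDE (Heun variant).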
def load_model(model_id="KBlueLeaf/Kohaku-XL-Zeta", device="cuda"):
    pipe: StableDiffusionXLKDiffusionPipeline
    pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)
    unet: UNet2DConditionModel = pipe.k_diffusion_model.inner_model.model
    pipe.scheduler.set_timesteps = partial(
        set_timesteps_polyexponential, pipe.scheduler, pipe.scheduler.sigmas
    )
    pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")
    pipe.k_diffusion_model.forward = model_forward(pipe.k_diffusion_model)
    return pipe
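

# Long-prompt encoding for SDXL's two CLIP text encoders: pad both prompts to
# a shared multiple of 75 usable tokens, encode each 75-token chunk with its
# own BOS/EOS pair, and concatenate the chunk embeddings back together (the
# usual workaround for CLIP's 77-token context window).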
@torch.no_grad()
def encode_prompts(
    pipe: StableDiffusionXLKDiffusionPipeline, prompt: str, neg_prompt: str = ""
):
    prompts = [prompt, neg_prompt]
    # Usable tokens per chunk: model_max_length (77) minus the BOS/EOS pair.
    max_length = pipe.tokenizer.model_max_length - 2
    input_ids = pipe.tokenizer(prompts, padding=True, return_tensors="pt")
    input_ids2 = pipe.tokenizer_2(prompts, padding=True, return_tensors="pt")
    length = max(input_ids.input_ids.size(-1), input_ids2.input_ids.size(-1))
    target_length = math.ceil(length / max_length) * max_length + 2
    input_ids = pipe.tokenizer(
        prompts, padding="max_length", max_length=target_length, return_tensors="pt"
    ).input_ids
    # Split into (BOS, token body, EOS) so chunks can be re-wrapped below.
    input_ids = (
        input_ids[:, 0:1],
        input_ids[:, 1:-1],
        input_ids[:, -1:],
    )
    input_ids2 = pipe.tokenizer_2(
        prompts, padding="max_length", max_length=target_length, return_tensors="pt"
    ).input_ids
    input_ids2 = (
        input_ids2[:, 0:1],
        input_ids2[:, 1:-1],
        input_ids2[:, -1:],
    )
    concat_embeds = []
    for i in range(0, input_ids[1].shape[-1], max_length):
        input_id1 = torch.concat(
            (input_ids[0], input_ids[1][:, i : i + max_length], input_ids[2]), dim=-1
        ).to(pipe.device)
        # Penultimate hidden state, as SDXL conditions on hidden_states[-2].
        result = pipe.text_encoder(input_id1, output_hidden_states=True).hidden_states[-2]
        # Trim the duplicated BOS/EOS positions so chunks concatenate cleanly.
        if i == 0:
            concat_embeds.append(result[:, :-1])
        elif i == input_ids[1].shape[-1] - max_length:
            concat_embeds.append(result[:, 1:])
        else:
            concat_embeds.append(result[:, 1:-1])
    concat_embeds2 = []
    pooled_embeds2 = []
    for i in range(0, input_ids2[1].shape[-1], max_length):
        input_id2 = torch.concat(
            (input_ids2[0], input_ids2[1][:, i : i + max_length], input_ids2[2]), dim=-1
        ).to(pipe.device)
        hidden_states = pipe.text_encoder_2(input_id2, output_hidden_states=True)
        # text_encoder_2 also provides the pooled embedding (first output).
        pooled_embeds2.append(hidden_states[0])
        if i == 0:
            concat_embeds2.append(hidden_states.hidden_states[-2][:, :-1])
        elif i == input_ids2[1].shape[-1] - max_length:
            concat_embeds2.append(hidden_states.hidden_states[-2][:, 1:])
        else:
            concat_embeds2.append(hidden_states.hidden_states[-2][:, 1:-1])

    # Join chunks along the sequence axis, concatenate the two encoders'
    # features along the channel axis, and average the pooled embeddings.
    prompt_embeds = torch.cat(concat_embeds, dim=1)
    prompt_embeds2 = torch.cat(concat_embeds2, dim=1)
    prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
    pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)
    return prompt_embeds, pooled_embeds2
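

# Demo: encode a positive/negative prompt pair in one batch, then split the
# embeddings into row 0 (positive) and row 1 (negative) for the pipeline call.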
if __name__ == "__main__":
    from meta import DEFAULT_NEGATIVE_PROMPT

    prompt = """
1girl,
king halo (umamusume), umamusume,
ogipote, misu kasumi, fuzichoco, ciloranko, ninjin nouka, ningen mame, ask (askzy), kita (kitairoha), amano kokoko, maccha (mochancc),
solo, leaning forward, cleavage, sky, cowboy shot, outdoors, cloud, long hair, looking at viewer, brown hair, day, horse girl, black bikini, cloudy sky, stomach, collarbone, blue sky, swimsuit, navel, thighs, blush, ocean, animal ears, standing, smile, breasts, open mouth, :d, red eyes, horse ears, tail, bare shoulders, wavy hair, bikini, medium breasts,
masterpiece, newest, absurdres, sensitive
""".strip()

    sdxl_pipe = load_model()
    # sdxl_pipe = load_model("KBlueLeaf/xxx")  # or point at another SDXL checkpoint
    prompt_embeds, pooled_embeds2 = encode_prompts(
        sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT
    )
    result = sdxl_pipe(
        prompt_embeds=prompt_embeds[0:1],
        negative_prompt_embeds=prompt_embeds[1:],
        pooled_prompt_embeds=pooled_embeds2[0:1],
        negative_pooled_prompt_embeds=pooled_embeds2[1:],
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]
    result.save("test.png")
    # Optional: torch.compile the denoiser for faster repeated sampling.
    # Compiling an nn.Module returns an OptimizedModule whose original module
    # stays reachable via ._orig_mod.
    module = torch.compile(sdxl_pipe.k_diffusion_model)
    if isinstance(module, torch._dynamo.OptimizedModule):
        original_module = module._orig_mod