import math
from functools import partial

import torch
from diffusers import StableDiffusionXLKDiffusionPipeline, UNet2DConditionModel
from k_diffusion.sampling import get_sigmas_polyexponential, sample_dpmpp_2m_sde

torch.set_float32_matmul_precision("medium")

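# Replacement for the scheduler's set_timesteps: build a polyexponential sigma
# schedule (rho ~ 2/3) with k-diffusion instead of the scheduler's default.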
def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device=None):
    self.num_inference_steps = num_inference_steps

    self.sigmas = get_sigmas_polyexponential(
        num_inference_steps + 1,
        sigma_min=orig_sigmas[-2],  # orig_sigmas[-1] is the terminal 0; take the last nonzero sigma
        sigma_max=orig_sigmas[0],
        rho=0.666666,
        device=device or "cpu",
    )
    # Drop the two smallest sigmas, then re-append a terminal zero so the
    # schedule still ends at sigma = 0 with num_inference_steps + 1 entries.
    self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas.new_zeros([1])])

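# Wrap the k-diffusion denoiser so it runs under fp16 autocast on CUDA while
# returning fp32 outputs, keeping the sampler's arithmetic in full precision.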
def model_forward(k_diffusion_model: torch.nn.Module):
    orig_forward = k_diffusion_model.forward

    def forward(*args, **kwargs):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            result = orig_forward(*args, **kwargs)
        return result.float()

    return forward

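# Assemble the pipeline: load the SDXL checkpoint in fp16, swap in the
# polyexponential schedule, select DPM++ 2M SDE (heun) as the sampler, and
# patch the denoiser's forward for mixed precision.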
def load_model(model_id="KBlueLeaf/Kohaku-XL-Zeta", device="cuda"):
    pipe: StableDiffusionXLKDiffusionPipeline
    pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)
    unet: UNet2DConditionModel = pipe.k_diffusion_model.inner_model.model  # typed handle to the underlying UNet
    pipe.scheduler.set_timesteps = partial(
        set_timesteps_polyexponential, pipe.scheduler, pipe.scheduler.sigmas
    )
    pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")
    pipe.k_diffusion_model.forward = model_forward(pipe.k_diffusion_model)
    return pipe

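# Long-prompt encoder: split token ids into 75-token chunks, run every chunk
# through both CLIP text encoders, and stitch the hidden states back together
# (the usual "prompt chunking" trick for prompts longer than 77 tokens).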
@torch.no_grad()
def encode_prompts(
    pipe: StableDiffusionXLKDiffusionPipeline, prompt: str, neg_prompt: str = ""
):
    prompts = [prompt, neg_prompt]
    # model_max_length includes BOS/EOS, so each chunk holds at most
    # model_max_length - 2 content tokens (77 - 2 = 75 for CLIP).
    max_length = pipe.tokenizer.model_max_length - 2

    # Measure the longer tokenization of the two prompts, then pad both up to
    # a whole number of chunks plus BOS and EOS.
    input_ids = pipe.tokenizer(prompts, padding=True, return_tensors="pt")
    input_ids2 = pipe.tokenizer_2(prompts, padding=True, return_tensors="pt")
    length = max(input_ids.input_ids.size(-1), input_ids2.input_ids.size(-1))
    target_length = math.ceil(length / max_length) * max_length + 2

    # Split each padded sequence into (BOS, body, EOS).
    input_ids = pipe.tokenizer(
        prompts, padding="max_length", max_length=target_length, return_tensors="pt"
    ).input_ids
    input_ids = (
        input_ids[:, 0:1],
        input_ids[:, 1:-1],
        input_ids[:, -1:],
    )
    input_ids2 = pipe.tokenizer_2(
        prompts, padding="max_length", max_length=target_length, return_tensors="pt"
    ).input_ids
    input_ids2 = (
        input_ids2[:, 0:1],
        input_ids2[:, 1:-1],
        input_ids2[:, -1:],
    )

    # First text encoder: re-wrap every chunk in BOS/EOS, take the penultimate
    # hidden state, and trim the duplicated BOS/EOS at chunk boundaries so the
    # concatenated sequence keeps a single BOS and a single EOS.
    concat_embeds = []
    for i in range(0, input_ids[1].shape[-1], max_length):
        input_id1 = torch.concat(
            (input_ids[0], input_ids[1][:, i : i + max_length], input_ids[2]), dim=-1
        ).to(pipe.device)
        result = pipe.text_encoder(input_id1, output_hidden_states=True).hidden_states[
            -2
        ]
        if i == 0:
            concat_embeds.append(result[:, :-1])
        elif i == input_ids[1].shape[-1] - max_length:
            concat_embeds.append(result[:, 1:])
        else:
            concat_embeds.append(result[:, 1:-1])

    # Second text encoder: same chunking, but also collect the pooled output
    # (first element of the return tuple) from every chunk.
    concat_embeds2 = []
    pooled_embeds2 = []
    for i in range(0, input_ids2[1].shape[-1], max_length):
        input_id2 = torch.concat(
            (input_ids2[0], input_ids2[1][:, i : i + max_length], input_ids2[2]), dim=-1
        ).to(pipe.device)
        hidden_states = pipe.text_encoder_2(input_id2, output_hidden_states=True)
        pooled_embeds2.append(hidden_states[0])
        if i == 0:
            concat_embeds2.append(hidden_states.hidden_states[-2][:, :-1])
        elif i == input_ids2[1].shape[-1] - max_length:
            concat_embeds2.append(hidden_states.hidden_states[-2][:, 1:])
        else:
            concat_embeds2.append(hidden_states.hidden_states[-2][:, 1:-1])

    # SDXL conditions on the two encoders' features concatenated channel-wise;
    # the pooled embeddings from all chunks are averaged into a single vector.
    prompt_embeds = torch.cat(concat_embeds, dim=1)
    prompt_embeds2 = torch.cat(concat_embeds2, dim=1)
    prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)

    pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)

    return prompt_embeds, pooled_embeds2

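# Example entry point: encode a long Danbooru-style tag prompt, sample one
# 1024x1024 image at CFG 6.0 over 24 steps, and save it to disk.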
if __name__ == "__main__":
    from meta import DEFAULT_NEGATIVE_PROMPT

    prompt = """
1girl,
king halo (umamusume), umamusume,

ogipote, misu kasumi, fuzichoco, ciloranko, ninjin nouka, ningen mame, ask (askzy), kita (kitairoha), amano kokoko, maccha (mochancc),

solo, leaning forward, cleavage, sky, cowboy shot, outdoors, cloud, long hair, looking at viewer, brown hair, day, horse girl, black bikini, cloudy sky, stomach, collarbone, blue sky, swimsuit, navel, thighs, blush, ocean, animal ears, standing, smile, breasts, open mouth, :d, red eyes, horse ears, tail, bare shoulders, wavy hair, bikini, medium breasts,

masterpiece, newest, absurdres, sensitive
""".strip()
    sdxl_pipe = load_model()  # defaults to the KBlueLeaf/Kohaku-XL-Zeta checkpoint

    # The prompt and negative prompt are encoded as a batch of two, then split
    # into positive (index 0) and negative (index 1) conditioning tensors.
    prompt_embeds, pooled_embeds2 = encode_prompts(
        sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT
    )
    result = sdxl_pipe(
        prompt_embeds=prompt_embeds[0:1],
        negative_prompt_embeds=prompt_embeds[1:],
        pooled_prompt_embeds=pooled_embeds2[0:1],
        negative_pooled_prompt_embeds=pooled_embeds2[1:],
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]

    result.save("test.png")