from functools import partial

import torch
from diffusers import StableDiffusionXLKDiffusionPipeline, UNet2DConditionModel
from k_diffusion.sampling import get_sigmas_polyexponential, sample_dpmpp_2m_sde

torch.set_float32_matmul_precision("medium")


def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device=None):
    """Drop-in replacement for the scheduler's set_timesteps that uses a
    polyexponential sigma schedule (polynomial in log-sigma space)."""
    self.num_inference_steps = num_inference_steps

    # orig_sigmas ends with 0, so orig_sigmas[-2] is the smallest real sigma
    # and orig_sigmas[0] the largest.
    self.sigmas = get_sigmas_polyexponential(
        num_inference_steps + 1,
        sigma_min=orig_sigmas[-2],
        sigma_max=orig_sigmas[0],
        rho=0.666666,
        device=device or "cpu",
    )
    # Drop the last sigma and the appended zero, then re-append an exact zero.
    self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas.new_zeros([1])])
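
# Note (illustrative): get_sigmas_polyexponential(n, ...) returns n sigmas plus
# an appended zero, so with num_inference_steps=24 the patched scheduler ends up
# with 25 values after the slice-and-re-append above. For intuition, with a
# hypothetical sigma range of ~0.03..14.6 (typical for SDXL):
#
#   >>> get_sigmas_polyexponential(25, 0.03, 14.6, rho=0.666666).shape
#   torch.Size([26])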


def model_forward(k_diffusion_model: torch.nn.Module):
    """Wrap the k-diffusion model's forward so the UNet runs under fp16
    autocast while the sampler sees fp32 outputs."""
    orig_forward = k_diffusion_model.forward

    def forward(*args, **kwargs):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            result = orig_forward(*args, **kwargs)
        return result.float()

    return forward


def load_model(model_id="KBlueLeaf/Kohaku-XL-Zeta", device="cuda"):
    pipe: StableDiffusionXLKDiffusionPipeline
    pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)
    # Handle to the underlying UNet (unused below, but convenient for tweaks).
    unet: UNet2DConditionModel = pipe.k_diffusion_model.inner_model.model

    # Monkey-patch the scheduler with the polyexponential schedule, binding the
    # original sigmas so sigma_min/sigma_max stay consistent with the model.
    pipe.scheduler.set_timesteps = partial(
        set_timesteps_polyexponential, pipe.scheduler, pipe.scheduler.sigmas
    )
    # Sample with DPM++ 2M SDE (Heun solver, mild stochasticity).
    pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")
    pipe.k_diffusion_model.forward = model_forward(pipe.k_diffusion_model)
    return pipe


def encode_prompts(pipe: StableDiffusionXLKDiffusionPipeline, prompt, neg_prompt):
    """Encode prompts of arbitrary length by chunking them into windows of the
    tokenizer's max length (77 tokens for CLIP)."""
    max_length = pipe.tokenizer.model_max_length

    # Tokenize without truncation; long prompts are chunked below.
    input_ids = pipe.tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    input_ids2 = pipe.tokenizer_2(prompt, return_tensors="pt").input_ids.to("cuda")

    # Pad the negative prompt to the positive prompt's token length.
    negative_ids = pipe.tokenizer(
        neg_prompt,
        truncation=False,
        padding="max_length",
        max_length=input_ids.shape[-1],
        return_tensors="pt",
    ).input_ids.to("cuda")
    negative_ids2 = pipe.tokenizer_2(
        neg_prompt,
        truncation=False,
        padding="max_length",
        max_length=input_ids.shape[-1],
        return_tensors="pt",
    ).input_ids.to("cuda")

    # If the negative prompt is longer, pad the positive prompt up to match.
    # (Compare sequence lengths explicitly rather than torch.Size objects,
    # which would compare lexicographically.)
    if negative_ids.shape[-1] > input_ids.shape[-1]:
        input_ids = pipe.tokenizer(
            prompt,
            truncation=False,
            padding="max_length",
            max_length=negative_ids.shape[-1],
            return_tensors="pt",
        ).input_ids.to("cuda")
        input_ids2 = pipe.tokenizer_2(
            prompt,
            truncation=False,
            padding="max_length",
            max_length=negative_ids.shape[-1],
            return_tensors="pt",
        ).input_ids.to("cuda")

    # First text encoder: last-hidden-state embeddings, one 77-token chunk at a time.
    concat_embeds = []
    neg_embeds = []
    for i in range(0, input_ids.shape[-1], max_length):
        concat_embeds.append(pipe.text_encoder(input_ids[:, i : i + max_length])[0])
        neg_embeds.append(pipe.text_encoder(negative_ids[:, i : i + max_length])[0])

    # Second text encoder: penultimate-layer embeddings plus the pooled
    # (projected) embedding from each chunk.
    concat_embeds2 = []
    neg_embeds2 = []
    pooled_embeds2 = []
    neg_pooled_embeds2 = []
    for i in range(0, input_ids.shape[-1], max_length):
        hidden_states = pipe.text_encoder_2(
            input_ids2[:, i : i + max_length], output_hidden_states=True
        )
        concat_embeds2.append(hidden_states.hidden_states[-2])
        pooled_embeds2.append(hidden_states[0])

        hidden_states = pipe.text_encoder_2(
            negative_ids2[:, i : i + max_length], output_hidden_states=True
        )
        neg_embeds2.append(hidden_states.hidden_states[-2])
        neg_pooled_embeds2.append(hidden_states[0])

    # Concatenate chunks along the sequence dimension, then join the two
    # encoders' features along the channel dimension, as SDXL expects.
    prompt_embeds = torch.cat(concat_embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
    prompt_embeds2 = torch.cat(concat_embeds2, dim=1)
    negative_prompt_embeds2 = torch.cat(neg_embeds2, dim=1)
    prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
    negative_prompt_embeds = torch.cat(
        [negative_prompt_embeds, negative_prompt_embeds2], dim=-1
    )

    # Average the pooled embeddings across chunks (one vector per prompt).
    pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)
    neg_pooled_embeds2 = torch.mean(torch.stack(neg_pooled_embeds2, dim=0), dim=0)

    return prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2
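

# --- Usage sketch (illustrative) ---
# A minimal way to tie the pieces together. The prompts, resolution, step count,
# and guidance scale are placeholder values, not settings prescribed by the model.
if __name__ == "__main__":
    pipe = load_model()

    prompt = "1girl, solo, looking at viewer"          # placeholder prompt
    neg_prompt = "lowres, worst quality, bad anatomy"  # placeholder negative prompt

    prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2 = (
        encode_prompts(pipe, prompt, neg_prompt)
    )

    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_embeds2,
        negative_pooled_prompt_embeds=neg_pooled_embeds2,
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]
    image.save("output.png")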