import math
from functools import partial, wraps

import torch
from diffusers import StableDiffusionXLKDiffusionPipeline, UNet2DConditionModel
from k_diffusion.sampling import get_sigmas_polyexponential
from k_diffusion.sampling import sample_dpmpp_2m_sde

torch.set_float32_matmul_precision("medium")


def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device=None):
    """Replace a scheduler's sigmas with a polyexponential schedule.

    Meant to be bound onto a scheduler with ``functools.partial`` (see
    ``load_model``) so it stands in for ``scheduler.set_timesteps``.

    Args:
        self: The scheduler whose ``num_inference_steps`` and ``sigmas`` are
            overwritten.
        orig_sigmas: The scheduler's original sigma table. The largest sigma
            is read from ``orig_sigmas[0]`` and the smallest from
            ``orig_sigmas[-2]``; this assumes the table is descending and ends
            with a trailing 0 -- TODO confirm for the scheduler in use.
        num_inference_steps: Number of sampling steps.
        device: Device for the new sigma tensor; defaults to CPU.
    """
    self.num_inference_steps = num_inference_steps
    self.sigmas = get_sigmas_polyexponential(
        num_inference_steps + 1,
        sigma_min=orig_sigmas[-2],
        sigma_max=orig_sigmas[0],
        rho=0.666666,
        device=device or "cpu",
    )
    # k-diffusion appends a trailing zero, so this holds one sigma more than
    # requested plus that zero. Drop the last two entries (the smallest nonzero
    # sigma and the zero) and re-append a single zero, leaving
    # num_inference_steps nonzero sigmas followed by 0.
    self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas.new_zeros([1])])


def model_forward(k_diffusion_model: torch.nn.Module):
    """Wrap ``k_diffusion_model.forward`` so it runs under fp16 autocast.

    The returned callable runs the original forward inside
    ``torch.autocast(device_type="cuda", dtype=torch.float16)`` and casts the
    result to float32.

    NOTE(review): the autocast device type is hard-coded to CUDA, so a pipeline
    loaded on another device gets no autocast from this wrapper -- confirm
    that is intended.
    """
    orig_forward = k_diffusion_model.forward

    @wraps(orig_forward)
    def forward(*args, **kwargs):
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            result = orig_forward(*args, **kwargs)
        return result.float()

    return forward


def load_model(model_id="KBlueLeaf/Kohaku-XL-Zeta", device="cuda"):
    """Load an SDXL k-diffusion pipeline configured for this script.

    The pipeline is loaded in float16 and moved to ``device``. Its scheduler's
    ``set_timesteps`` is replaced by ``set_timesteps_polyexponential`` bound to
    the scheduler's sigma table as it was at load time. ``pipe.sampler`` is set
    to DPM++ 2M SDE (``eta=0.35``, Heun solver), and the k-diffusion model's
    forward is wrapped by ``model_forward``.

    Args:
        model_id: Hub repo id or local path passed to ``from_pretrained``.
        device: Device the pipeline is moved to.

    Returns:
        The configured ``StableDiffusionXLKDiffusionPipeline``.
    """
    pipe: StableDiffusionXLKDiffusionPipeline = (
        StableDiffusionXLKDiffusionPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        ).to(device)
    )
    # Bind the load-time sigma tensor: set_timesteps_polyexponential rebinds
    # scheduler.sigmas to a new tensor, so later calls still derive
    # sigma_min/sigma_max from this original table.
    pipe.scheduler.set_timesteps = partial(
        set_timesteps_polyexponential, pipe.scheduler, pipe.scheduler.sigmas
    )
    pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")
    pipe.k_diffusion_model.forward = model_forward(pipe.k_diffusion_model)
    return pipe


def _tokenize_split(tokenizer, prompts, target_length):
    """Tokenize ``prompts`` padded to ``target_length`` and split off the ends.

    Returns a ``(first, body, last)`` tuple of column slices of the token-id
    matrix: the first position, everything in between, and the final position.
    """
    ids = tokenizer(
        prompts, padding="max_length", max_length=target_length, return_tensors="pt"
    ).input_ids
    return ids[:, 0:1], ids[:, 1:-1], ids[:, -1:]


def _encode_chunked(text_encoder, split_ids, max_length, device):
    """Encode a long token sequence window by window and stitch the results.

    Each window is ``first + body[start:start + max_length] + last``. The
    penultimate hidden state of every window is kept. The shared leading
    position is kept only from the first window and the shared trailing
    position only from the last, so a single window is kept whole.

    Args:
        text_encoder: Encoder called as ``text_encoder(ids, output_hidden_states=True)``.
        split_ids: ``(first, body, last)`` from ``_tokenize_split``; the body
            length is assumed to be a multiple of ``max_length``.
        max_length: Number of body tokens per window.
        device: Device the window ids are moved to before encoding.

    Returns:
        ``(embeds, pooled)``: the stitched hidden states concatenated along
        dim 1, and a list holding each window's ``output[0]``.
    """
    first, body, last = split_ids
    body_length = body.shape[-1]
    embeds = []
    pooled = []
    for start in range(0, body_length, max_length):
        window = torch.concat(
            (first, body[:, start : start + max_length], last), dim=-1
        ).to(device)
        output = text_encoder(window, output_hidden_states=True)
        pooled.append(output[0])
        hidden = output.hidden_states[-2]
        keep_from = 0 if start == 0 else 1
        keep_to = None if start + max_length >= body_length else -1
        embeds.append(hidden[:, keep_from:keep_to])
    return torch.cat(embeds, dim=1), pooled


@torch.no_grad()
def encode_prompts(
    pipe: StableDiffusionXLKDiffusionPipeline, prompt: str, neg_prompt: str = ""
):
    """Encode a prompt/negative-prompt pair with both SDXL text encoders.

    Prompts longer than one encoder window are split into windows of
    ``tokenizer.model_max_length - 2`` tokens, encoded separately, and stitched
    back together. Both prompts are padded to the same length, so their
    embeddings can be stacked in one batch.

    NOTE(review): every window reuses the padded sequence's final position as
    its trailing token. That position is padding, not EOS, whenever a prompt is
    shorter than the padded length. For a tokenizer whose pad token differs
    from EOS, middle windows therefore contain no EOS -- confirm this matches
    how the model was trained.

    Args:
        pipe: Pipeline providing ``tokenizer``/``tokenizer_2``,
            ``text_encoder``/``text_encoder_2`` and ``device``.
        prompt: Positive prompt.
        neg_prompt: Negative prompt.

    Returns:
        ``(prompt_embeds, pooled_embeds)``. Row 0 belongs to ``prompt`` and
        row 1 to ``neg_prompt``. ``prompt_embeds`` concatenates both encoders'
        penultimate hidden states along the last dim. ``pooled_embeds`` is the
        mean of ``text_encoder_2``'s ``output[0]`` over all windows.
    """
    prompts = [prompt, neg_prompt]
    # Body tokens per window: the encoder window minus the first/last slots.
    max_length = pipe.tokenizer.model_max_length - 2

    length = max(
        pipe.tokenizer(prompts, padding=True, return_tensors="pt").input_ids.size(-1),
        pipe.tokenizer_2(prompts, padding=True, return_tensors="pt").input_ids.size(-1),
    )
    # `length` counts the BOS/EOS tokens but `max_length` does not, so subtract
    # them before sizing; otherwise a prompt that exactly fills one window
    # would get an extra all-padding window. Always use at least one window.
    num_chunks = max(1, math.ceil((length - 2) / max_length))
    target_length = num_chunks * max_length + 2

    prompt_embeds, _ = _encode_chunked(
        pipe.text_encoder,
        _tokenize_split(pipe.tokenizer, prompts, target_length),
        max_length,
        pipe.device,
    )
    prompt_embeds2, pooled_embeds2 = _encode_chunked(
        pipe.text_encoder_2,
        _tokenize_split(pipe.tokenizer_2, prompts, target_length),
        max_length,
        pipe.device,
    )

    prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
    pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)
    return prompt_embeds, pooled_embeds2


if __name__ == "__main__":
    from meta import DEFAULT_NEGATIVE_PROMPT

    prompt = """
1girl, king halo (umamusume), umamusume, ogipote, misu kasumi, fuzichoco, ciloranko, ninjin nouka, ningen mame, ask (askzy), kita (kitairoha), amano kokoko, maccha (mochancc), solo, leaning forward, cleavage, sky, cowboy shot, outdoors, cloud, long hair,
looking at viewer, brown hair, day, horse girl, black bikini, cloudy sky, stomach, collarbone, blue sky, swimsuit, navel, thighs, blush, ocean, animal ears, standing, smile, breasts, open mouth, :d, red eyes, horse ears, tail, bare shoulders, wavy hair, bikini, medium breasts, masterpiece, newest, absurdres, sensitive
""".strip()

    # NOTE(review): "KBlueLeaf/xxx" looks like a placeholder repo id; the
    # commented-out call below loads the default Kohaku-XL-Zeta checkpoint.
    sdxl_pipe = load_model("KBlueLeaf/xxx")
    # sdxl_pipe = load_model()
    prompt_embeds, pooled_embeds2 = encode_prompts(
        sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT
    )
    # Row 0 is the positive prompt, row 1 the negative prompt.
    result = sdxl_pipe(
        prompt_embeds=prompt_embeds[0:1],
        negative_prompt_embeds=prompt_embeds[1:],
        pooled_prompt_embeds=pooled_embeds2[0:1],
        negative_pooled_prompt_embeds=pooled_embeds2[1:],
        num_inference_steps=24,
        width=1024,
        height=1024,
        guidance_scale=6.0,
    ).images[0]
    result.save("test.png")