AlekseyCalvin committed
Commit cb4a9fb
Parent(s): caaeec1
Update pipeline.py

pipeline.py: +45 -59

pipeline.py CHANGED
@@ -292,11 +292,9 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(
-        text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
-        negative_text_ids = torch.zeros(batch_size, negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
-        return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
+        return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
 
     def check_inputs(
         self,
@@ -485,13 +483,11 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_inference_steps: int = 8,
         timesteps: List[int] = None,
-        eta: Optional[float] = 0.0,
         guidance_scale: float = 3.5,
         device: Optional[int] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -499,14 +495,13 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         cfg: Optional[bool] = True,
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
-        **kwargs,
     ):
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
@@ -518,9 +513,8 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
             height,
             width,
             prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
-
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
         )
 
@@ -546,21 +540,16 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
             pooled_prompt_embeds,
             text_ids,
             negative_prompt_embeds,
-            negative_pooled_prompt_embeds
-            negative_text_ids,
+            negative_pooled_prompt_embeds
         ) = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
             device=device,
             negative_prompt=negative_prompt,
-            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             lora_scale=lora_scale,
         )
 
@@ -607,67 +596,64 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
-
+
+                # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
-                timestep = t.expand(latent_model_input.shape[0])
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latent_model_input.shape[0])
 
-
-                    guidance = torch.tensor([guidance_scale], device=device)
-                    guidance = guidance.expand(latents.shape[0])
-                else:
-                    guidance = None
-
-                noise_pred_text = self.transformer(
+                noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep / 1000,
-
-                    pooled_projections=pooled_prompt_embeds.shape[1],
+                    pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
-
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=negative_pooled_prompt_embeds.shape[1],
-                    encoder_hidden_states=negative_prompt_embeds,
-                    txt_ids=negative_text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
+
                 if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(
-                    noise_pred = noise_pred_uncond + self.
-
-
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
+
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                         latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
 
-                # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
 
-
-
+        if output_type == "latent":
+            image = latents
+
+        else:
+            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+            image = self.vae.decode(latents, return_dict=False)[0]
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
         self.maybe_free_model_hooks()
-
-
-
-
-
-        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
-        image = vae.decode(latents, return_dict=False)[0]
-        return self.image_processor.postprocess(image, output_type=output_type)[0]
+
+        if not return_dict:
+            return (image,)
+
+        return FluxPipelineOutput(images=image)
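For reference, a minimal usage sketch of the updated __call__ signature after this commit. Only the parameter names (negative_prompt, guidance_scale, num_inference_steps, max_sequence_length, callback_on_step_end, callback_on_step_end_tensor_inputs) and the callback contract, called as callback_on_step_end(self, i, t, callback_kwargs) and expected to return a dict, come from the diff above; the checkpoint id, dtype, device, and import path are illustrative assumptions, not part of the commit.

# Hypothetical usage sketch; checkpoint id, dtype, device, and import path are assumptions.
import torch
from pipeline import FluxWithCFGPipeline  # pipeline.py from this Space (assumed import path)

pipe = FluxWithCFGPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # assumed base checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

def on_step_end(pipeline, i, t, callback_kwargs):
    # Receives the tensors named in callback_on_step_end_tensor_inputs and must
    # return a dict; "latents" / "prompt_embeds" in the returned dict override
    # the loop's local copies.
    print(f"step {i}: latents {tuple(callback_kwargs['latents'].shape)}")
    return callback_kwargs

image = pipe(
    prompt="a photo of a red fox in the snow",
    negative_prompt="blurry, low quality",
    num_inference_steps=8,
    guidance_scale=3.5,
    max_sequence_length=512,
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]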