smoothieAI committed on
Commit 0133547 · verified · 1 Parent(s): 246c970

Update pipeline.py

Files changed (1):
  1. pipeline.py +874 -450

pipeline.py CHANGED
@@ -14,21 +14,18 @@
14
 
15
  import inspect
16
  from dataclasses import dataclass
17
- from typing import Any, Callable, Dict, List, Optional, Union
18
 
19
  import numpy as np
20
  import torch
21
- import torch.nn.functional as F
22
- from PIL import Image
23
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
24
 
 
25
  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
26
  from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27
- from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel
28
  from diffusers.models.lora import adjust_lora_scale_text_encoder
29
  from diffusers.models.unet_motion_model import MotionAdapter
30
- from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
31
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
  from diffusers.schedulers import (
33
  DDIMScheduler,
34
  DPMSolverMultistepScheduler,
@@ -37,9 +34,27 @@ from diffusers.schedulers import (
37
  LMSDiscreteScheduler,
38
  PNDMScheduler,
39
  )
40
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
41
  from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
42
 
43
 
44
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
 
@@ -47,49 +62,72 @@ EXAMPLE_DOC_STRING = """
47
  Examples:
48
  ```py
49
  >>> import torch
50
- >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter
51
- >>> from diffusers.pipelines import DiffusionPipeline
52
- >>> from diffusers.schedulers import DPMSolverMultistepScheduler
53
- >>> from PIL import Image
54
-
55
- >>> motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
56
- >>> adapter = MotionAdapter.from_pretrained(motion_id)
57
- >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
58
- >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
59
-
60
- >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
61
- >>> pipe = DiffusionPipeline.from_pretrained(
62
- ... model_id,
63
- ... motion_adapter=adapter,
64
- ... controlnet=controlnet,
65
- ... vae=vae,
66
- ... custom_pipeline="pipeline_animatediff_controlnet",
67
- ... ).to(device="cuda", dtype=torch.float16)
68
- >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
69
- ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
70
- ... )
71
- >>> pipe.enable_vae_slicing()
72
-
73
- >>> conditioning_frames = []
74
- >>> for i in range(1, 16 + 1):
75
- ... conditioning_frames.append(Image.open(f"frame_{i}.png"))
76
-
77
- >>> prompt = "astronaut in space, dancing"
78
- >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
79
- >>> result = pipe(
80
- ... prompt=prompt,
81
- ... negative_prompt=negative_prompt,
82
- ... width=512,
83
- ... height=768,
84
- ... conditioning_frames=conditioning_frames,
85
- ... num_inference_steps=12,
86
- ... ).frames[0]
87
-
88
  >>> from diffusers.utils import export_to_gif
89
- >>> export_to_gif(result.frames[0], "result.gif")
90
  ```
91
  """
92
 
93
 
94
  def tensor2vid(video: torch.Tensor, processor, output_type="np"):
95
  # Based on:
@@ -107,23 +145,20 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
107
 
108
 
109
  @dataclass
110
- class AnimateDiffControlNetPipelineOutput(BaseOutput):
111
  frames: Union[torch.Tensor, np.ndarray]
112
 
113
 
114
- class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
115
  r"""
116
  Pipeline for text-to-video generation.
117
-
118
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
119
  implemented for all pipelines (downloading, saving, running on a particular device, etc.).
120
-
121
  The pipeline also inherits the following loading methods:
122
  - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
123
  - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
124
  - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
125
  - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
126
-
127
  Args:
128
  vae ([`AutoencoderKL`]):
129
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -140,9 +175,8 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
140
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
141
  """
142
 
143
- model_cpu_offload_seq = "text_encoder->unet->vae"
144
- _optional_components = ["feature_extractor", "image_encoder"]
145
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
146
 
147
  def __init__(
148
  self,
@@ -151,7 +185,6 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
151
  tokenizer: CLIPTokenizer,
152
  unet: UNet2DConditionModel,
153
  motion_adapter: MotionAdapter,
154
- controlnet: Union[ControlNetModel, MultiControlNetModel],
155
  scheduler: Union[
156
  DDIMScheduler,
157
  PNDMScheduler,
@@ -160,12 +193,23 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
160
  EulerAncestralDiscreteScheduler,
161
  DPMSolverMultistepScheduler,
162
  ],
 
163
  feature_extractor: Optional[CLIPImageProcessor] = None,
164
  image_encoder: Optional[CLIPVisionModelWithProjection] = None,
165
  ):
166
  super().__init__()
167
  unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
168
-
169
  self.register_modules(
170
  vae=vae,
171
  text_encoder=text_encoder,
@@ -183,6 +227,9 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
183
  vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
184
  )
185
 
 
 
 
186
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
187
  def encode_prompt(
188
  self,
@@ -198,7 +245,6 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
198
  ):
199
  r"""
200
  Encodes the prompt into text encoder hidden states.
201
-
202
  Args:
203
  prompt (`str` or `List[str]`, *optional*):
204
  prompt to be encoded
@@ -366,18 +412,29 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
366
  return prompt_embeds, negative_prompt_embeds
367
 
368
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
369
- def encode_image(self, image, device, num_images_per_prompt):
370
  dtype = next(self.image_encoder.parameters()).dtype
371
 
372
  if not isinstance(image, torch.Tensor):
373
  image = self.feature_extractor(image, return_tensors="pt").pixel_values
374
 
375
  image = image.to(device=device, dtype=dtype)
376
- image_embeds = self.image_encoder(image).image_embeds
377
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
378
 
379
- uncond_image_embeds = torch.zeros_like(image_embeds)
380
- return image_embeds, uncond_image_embeds
381
 
382
  # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
383
  def decode_latents(self, latents):
@@ -439,12 +496,9 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
439
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
440
  def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
441
  r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
442
-
443
  The suffixes after the scaling factors represent the stages where they are being applied.
444
-
445
  Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
446
  that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
447
-
448
  Args:
449
  s1 (`float`):
450
  Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
@@ -493,10 +547,6 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
493
  prompt_embeds=None,
494
  negative_prompt_embeds=None,
495
  callback_on_step_end_tensor_inputs=None,
496
- image=None,
497
- controlnet_conditioning_scale=1.0,
498
- control_guidance_start=0.0,
499
- control_guidance_end=1.0,
500
  ):
501
  if height % 8 != 0 or width % 8 != 0:
502
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -539,147 +589,323 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
539
  f" {negative_prompt_embeds.shape}."
540
  )
541
 
542
- # `prompt` needs more sophisticated handling when there are multiple
543
- # conditionings.
544
- if isinstance(self.controlnet, MultiControlNetModel):
545
- if isinstance(prompt, list):
546
- logger.warning(
547
- f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
548
- " prompts. The conditionings will be fixed across the prompts."
549
- )
550
-
551
- # Check `image`
552
- is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
553
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
554
  )
555
- if (
556
- isinstance(self.controlnet, ControlNetModel)
557
- or is_compiled
558
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
559
- ):
560
- if isinstance(image, list):
561
- for image_ in image:
562
- self.check_image(image_, prompt, prompt_embeds)
563
- else:
564
- self.check_image(image, prompt, prompt_embeds)
565
- elif (
566
- isinstance(self.controlnet, MultiControlNetModel)
567
- or is_compiled
568
- and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
569
- ):
570
- if not isinstance(image, list):
571
- raise TypeError("For multiple controlnets: `image` must be type `list`")
572
-
573
- # When `image` is a nested list:
574
- # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
575
- elif any(isinstance(i, list) for i in image):
576
- raise ValueError("A single batch of multiple conditionings are supported at the moment.")
577
- elif len(image) != len(self.controlnet.nets):
578
- raise ValueError(
579
- f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
580
- )
581
 
582
- for control_ in image:
583
- for image_ in control_:
584
- self.check_image(image_, prompt, prompt_embeds)
585
  else:
586
- assert False
587
 
588
- # Check `controlnet_conditioning_scale`
589
- if (
590
- isinstance(self.controlnet, ControlNetModel)
591
- or is_compiled
592
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
593
- ):
594
- if not isinstance(controlnet_conditioning_scale, float):
595
- raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
596
- elif (
597
- isinstance(self.controlnet, MultiControlNetModel)
598
- or is_compiled
599
- and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
600
- ):
601
- if isinstance(controlnet_conditioning_scale, list):
602
- if any(isinstance(i, list) for i in controlnet_conditioning_scale):
603
- raise ValueError("A single batch of multiple conditionings are supported at the moment.")
604
- elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
605
- self.controlnet.nets
606
- ):
607
- raise ValueError(
608
- "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
609
- " the same length as the number of controlnets"
610
- )
611
  else:
612
- assert False
613
 
614
- if not isinstance(control_guidance_start, (tuple, list)):
615
- control_guidance_start = [control_guidance_start]
616
 
617
- if not isinstance(control_guidance_end, (tuple, list)):
618
- control_guidance_end = [control_guidance_end]
 
619
 
620
- if len(control_guidance_start) != len(control_guidance_end):
621
  raise ValueError(
622
- f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
 
623
  )
624
 
625
- if isinstance(self.controlnet, MultiControlNetModel):
626
- if len(control_guidance_start) != len(self.controlnet.nets):
627
- raise ValueError(
628
- f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
629
- )
630
 
631
- for start, end in zip(control_guidance_start, control_guidance_end):
632
- if start >= end:
633
- raise ValueError(
634
- f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
635
  )
636
- if start < 0.0:
637
- raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
638
- if end > 1.0:
639
- raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
640
-
641
- # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
642
- def check_image(self, image, prompt, prompt_embeds):
643
- image_is_pil = isinstance(image, Image.Image)
644
- image_is_tensor = isinstance(image, torch.Tensor)
645
- image_is_np = isinstance(image, np.ndarray)
646
- image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image)
647
- image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
648
- image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
649
-
650
- if (
651
- not image_is_pil
652
- and not image_is_tensor
653
- and not image_is_np
654
- and not image_is_pil_list
655
- and not image_is_tensor_list
656
- and not image_is_np_list
657
- ):
658
- raise TypeError(
659
- f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
660
  )
661
 
662
- if image_is_pil:
663
- image_batch_size = 1
664
- else:
665
- image_batch_size = len(image)
666
 
667
- if prompt is not None and isinstance(prompt, str):
668
- prompt_batch_size = 1
669
- elif prompt is not None and isinstance(prompt, list):
670
- prompt_batch_size = len(prompt)
671
- elif prompt_embeds is not None:
672
- prompt_batch_size = prompt_embeds.shape[0]
673
 
674
- if image_batch_size != 1 and image_batch_size != prompt_batch_size:
675
  raise ValueError(
676
- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
 
677
  )
678
 
679
- # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
680
- def prepare_latents(
681
- self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
682
  ):
683
  shape = (
684
  batch_size,
685
  num_channels_latents,
@@ -687,6 +913,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
687
  height // self.vae_scale_factor,
688
  width // self.vae_scale_factor,
689
  )
 
690
  if isinstance(generator, list) and len(generator) != batch_size:
691
  raise ValueError(
692
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -694,16 +921,63 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
694
  )
695
 
696
  if latents is None:
697
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
698
- else:
699
- latents = latents.to(device)
700
 
701
- # scale the initial noise by the standard deviation required by the scheduler
702
- latents = latents * self.scheduler.init_noise_sigma
703
- return latents
704
 
705
  # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
706
- def prepare_image(
707
  self,
708
  image,
709
  width,
@@ -716,51 +990,32 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
716
  guess_mode=False,
717
  ):
718
  image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
719
- image_batch_size = image.shape[0]
 
720
 
721
- if image_batch_size == 1:
722
- repeat_by = batch_size
723
- else:
724
- # image batch size is the same as prompt batch size
725
- repeat_by = num_images_per_prompt
726
 
727
- image = image.repeat_interleave(repeat_by, dim=0)
728
 
729
  image = image.to(device=device, dtype=dtype)
730
 
731
- if do_classifier_free_guidance and not guess_mode:
732
- image = torch.cat([image] * 2)
733
 
734
  return image
735
-
736
- @property
737
- def guidance_scale(self):
738
- return self._guidance_scale
739
-
740
- @property
741
- def clip_skip(self):
742
- return self._clip_skip
743
-
744
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
745
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
746
- # corresponds to doing no classifier free guidance.
747
- @property
748
- def do_classifier_free_guidance(self):
749
- return self._guidance_scale > 1
750
-
751
- @property
752
- def cross_attention_kwargs(self):
753
- return self._cross_attention_kwargs
754
-
755
- @property
756
- def num_timesteps(self):
757
- return self._num_timesteps
758
-
759
  @torch.no_grad()
760
  def __call__(
761
  self,
762
  prompt: Union[str, List[str]] = None,
763
  num_frames: Optional[int] = 16,
 
 
 
764
  height: Optional[int] = None,
765
  width: Optional[int] = None,
766
  num_inference_steps: int = 50,
@@ -773,22 +1028,31 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
773
  prompt_embeds: Optional[torch.FloatTensor] = None,
774
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
775
  ip_adapter_image: Optional[PipelineImageInput] = None,
776
- conditioning_frames: Optional[List[PipelineImageInput]] = None,
777
  output_type: Optional[str] = "pil",
 
778
  return_dict: bool = True,
 
 
779
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
780
  controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
781
- guess_mode: bool = False,
782
  control_guidance_start: Union[float, List[float]] = 0.0,
783
  control_guidance_end: Union[float, List[float]] = 1.0,
784
- clip_skip: Optional[int] = None,
785
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
786
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
787
- **kwargs,
788
  ):
789
  r"""
790
  The call function to the pipeline for generation.
791
-
792
  Args:
793
  prompt (`str` or `List[str]`, *optional*):
794
  The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
@@ -825,83 +1089,49 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
825
  negative_prompt_embeds (`torch.FloatTensor`, *optional*):
826
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
827
  not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
828
- ip_adapter_image (`PipelineImageInput`, *optional*):
829
- Optional image input to work with IP Adapters.
830
- conditioning_frames (`List[PipelineImageInput]`, *optional*):
831
- The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets
832
- are specified, images must be passed as a list such that each element of the list can be correctly
833
- batched for input to a single ControlNet.
834
  output_type (`str`, *optional*, defaults to `"pil"`):
835
  The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
836
  `np.array`.
837
  return_dict (`bool`, *optional*, defaults to `True`):
838
  Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
839
  of a plain tuple.
840
  cross_attention_kwargs (`dict`, *optional*):
841
  A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
842
  [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
843
- controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
844
- The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
845
- to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
846
- the corresponding scale as a list.
847
- guess_mode (`bool`, *optional*, defaults to `False`):
848
- The ControlNet encoder tries to recognize the content of the input image even if you remove all
849
- prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
850
- control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
851
- The percentage of total steps at which the ControlNet starts applying.
852
- control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
853
- The percentage of total steps at which the ControlNet stops applying.
854
  clip_skip (`int`, *optional*):
855
  Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
856
  the output of the pre-final layer will be used for computing the prompt embeddings.
857
- callback_on_step_end (`Callable`, *optional*):
858
- A function that is called at the end of each denoising step during inference. The function is called
859
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
860
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
861
- `callback_on_step_end_tensor_inputs`.
862
- callback_on_step_end_tensor_inputs (`List`, *optional*):
863
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
864
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
865
- `._callback_tensor_inputs` attribute of your pipeline class.
866
-
867
  Examples:
868
-
869
  Returns:
870
  [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
871
  If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
872
  returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
873
  """
874
-
875
- callback = kwargs.pop("callback", None)
876
- callback_steps = kwargs.pop("callback_steps", None)
877
-
878
- if callback is not None:
879
- deprecate(
880
- "callback",
881
- "1.0.0",
882
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
883
- )
884
- if callback_steps is not None:
885
- deprecate(
886
- "callback_steps",
887
- "1.0.0",
888
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
889
- )
890
-
891
- controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
892
-
893
- # align format for control guidance
894
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
895
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
896
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
897
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
898
- elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
899
- mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
900
- control_guidance_start, control_guidance_end = (
901
- mult * [control_guidance_start],
902
- mult * [control_guidance_end],
903
- )
904
-
905
  # 0. Default height and width to unet
906
  height = height or self.unet.config.sample_size * self.vae_scale_factor
907
  width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -910,24 +1140,9 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
910
 
911
  # 1. Check inputs. Raise error if not correct
912
  self.check_inputs(
913
- prompt=prompt,
914
- height=height,
915
- width=width,
916
- callback_steps=callback_steps,
917
- negative_prompt=negative_prompt,
918
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
919
- prompt_embeds=prompt_embeds,
920
- negative_prompt_embeds=negative_prompt_embeds,
921
- image=conditioning_frames,
922
- controlnet_conditioning_scale=controlnet_conditioning_scale,
923
- control_guidance_start=control_guidance_start,
924
- control_guidance_end=control_guidance_end,
925
  )
926
 
927
- self._guidance_scale = guidance_scale
928
- self._clip_skip = clip_skip
929
- self._cross_attention_kwargs = cross_attention_kwargs
930
-
931
  # 2. Define call parameters
932
  if prompt is not None and isinstance(prompt, str):
933
  batch_size = 1
@@ -937,16 +1152,23 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
937
  batch_size = prompt_embeds.shape[0]
938
 
939
  device = self._execution_device
940
-
941
- if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
942
- controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
943
-
944
- global_pool_conditions = (
945
- controlnet.config.global_pool_conditions
946
- if isinstance(controlnet, ControlNetModel)
947
- else controlnet.nets[0].config.global_pool_conditions
948
- )
949
- guess_mode = guess_mode or global_pool_conditions
950
 
951
  # 3. Encode input prompt
952
  text_encoder_lora_scale = (
@@ -956,180 +1178,382 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
956
  prompt,
957
  device,
958
  num_videos_per_prompt,
959
- self.do_classifier_free_guidance,
960
  negative_prompt,
961
  prompt_embeds=prompt_embeds,
962
  negative_prompt_embeds=negative_prompt_embeds,
963
  lora_scale=text_encoder_lora_scale,
964
- clip_skip=self.clip_skip,
965
  )
966
  # For classifier free guidance, we need to do two forward passes.
967
  # Here we concatenate the unconditional and text embeddings into a single batch
968
  # to avoid doing two forward passes
969
- if self.do_classifier_free_guidance:
970
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
971
 
972
  if ip_adapter_image is not None:
973
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt)
974
- if self.do_classifier_free_guidance:
 
 
 
975
  image_embeds = torch.cat([negative_image_embeds, image_embeds])
976
 
977
- if isinstance(controlnet, ControlNetModel):
978
- conditioning_frames = self.prepare_image(
979
- image=conditioning_frames,
980
- width=width,
981
- height=height,
982
- batch_size=batch_size * num_videos_per_prompt * num_frames,
983
- num_images_per_prompt=num_videos_per_prompt,
984
- device=device,
985
- dtype=controlnet.dtype,
986
- do_classifier_free_guidance=self.do_classifier_free_guidance,
987
- guess_mode=guess_mode,
988
- )
989
- elif isinstance(controlnet, MultiControlNetModel):
990
- cond_prepared_frames = []
991
- for frame_ in conditioning_frames:
992
- prepared_frame = self.prepare_image(
993
- image=frame_,
994
  width=width,
995
  height=height,
996
  batch_size=batch_size * num_videos_per_prompt * num_frames,
997
  num_images_per_prompt=num_videos_per_prompt,
998
  device=device,
999
  dtype=controlnet.dtype,
1000
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1001
  guess_mode=guess_mode,
1002
  )
1003
-
1004
- cond_prepared_frames.append(prepared_frame)
1005
-
1006
- conditioning_frames = cond_prepared_frames
1007
- else:
1008
- assert False
1009
 
1010
  # 4. Prepare timesteps
1011
  self.scheduler.set_timesteps(num_inference_steps, device=device)
1012
  timesteps = self.scheduler.timesteps
1013
- self._num_timesteps = len(timesteps)
1014
 
 
 
 
1015
  # 5. Prepare latent variables
1016
  num_channels_latents = self.unet.config.in_channels
1017
- latents = self.prepare_latents(
1018
- batch_size * num_videos_per_prompt,
1019
- num_channels_latents,
1020
- num_frames,
1021
- height,
1022
- width,
1023
- prompt_embeds.dtype,
1024
- device,
1025
- generator,
1026
- latents,
1027
- )
1028
 
1029
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1030
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1031
-
1032
- # 7. Add image embeds for IP-Adapter
1033
  added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
1034
-
1035
  # 7.1 Create tensor stating which controlnets to keep
1036
- controlnet_keep = []
1037
- for i in range(len(timesteps)):
1038
- keeps = [
1039
- 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1040
- for s, e in zip(control_guidance_start, control_guidance_end)
1041
- ]
1042
- controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
1043
-
1044
- print("############ START ANIMATEDIFF CONTZROLNET PIPELINE #############")
1045
  # Denoising loop
1046
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1047
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1048
  for i, t in enumerate(timesteps):
1049
- # expand the latents if we are doing classifier free guidance
1050
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1051
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1052
-
1053
- if guess_mode and self.do_classifier_free_guidance:
1054
- # Infer ControlNet only for the conditional batch.
1055
- control_model_input = latents
1056
- control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1057
- controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1058
- else:
1059
- control_model_input = latent_model_input
1060
- controlnet_prompt_embeds = prompt_embeds
1061
- controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
1062
-
1063
- if isinstance(controlnet_keep[i], list):
1064
- cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1065
- else:
1066
- controlnet_cond_scale = controlnet_conditioning_scale
1067
- if isinstance(controlnet_cond_scale, list):
1068
- controlnet_cond_scale = controlnet_cond_scale[0]
1069
- cond_scale = controlnet_cond_scale * controlnet_keep[i]
1070
-
1071
- print("-----------------------")
1072
- print("control_model_input.shape", control_model_input.shape)
1073
 
1074
- control_model_input = torch.transpose(control_model_input, 1, 2)
1075
- control_model_input = control_model_input.reshape(
1076
- (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
1077
- )
1078
- print("prompt_embeds.shape", prompt_embeds.shape)
1079
- print("control_model_input.shape", control_model_input.shape)
1080
- print("controlnet_prompt_embeds.shape", controlnet_prompt_embeds.shape)
1081
- print("conditioning_frames.shape", conditioning_frames.shape)
1082
- print("cond_scale", cond_scale)
1083
- print("guess_mode", guess_mode)
1084
-
1085
- down_block_res_samples, mid_block_res_sample = self.controlnet(
1086
- control_model_input,
1087
- t,
1088
- encoder_hidden_states=controlnet_prompt_embeds,
1089
- controlnet_cond=conditioning_frames,
1090
- conditioning_scale=cond_scale,
1091
- guess_mode=guess_mode,
1092
- return_dict=False,
1093
- )
1094
-
1095
- # predict the noise residual
1096
- noise_pred = self.unet(
1097
- latent_model_input,
1098
- t,
1099
- encoder_hidden_states=prompt_embeds,
1100
- cross_attention_kwargs=self.cross_attention_kwargs,
1101
- added_cond_kwargs=added_cond_kwargs,
1102
- down_block_additional_residuals=down_block_res_samples,
1103
- mid_block_additional_residual=mid_block_res_sample,
1104
- ).sample
1105
-
1106
  # perform guidance
1107
- if self.do_classifier_free_guidance:
1108
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 
 
 
1109
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1110
-
1111
  # compute the previous noisy sample x_t -> x_t-1
1112
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1113
-
1114
- if callback_on_step_end is not None:
1115
- callback_kwargs = {}
1116
- for k in callback_on_step_end_tensor_inputs:
1117
- callback_kwargs[k] = locals()[k]
1118
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1119
-
1120
- latents = callback_outputs.pop("latents", latents)
1121
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1122
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1123
-
1124
  # call the callback, if provided
1125
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1126
  progress_bar.update()
1127
  if callback is not None and i % callback_steps == 0:
1128
- callback(i, t, latents)
1129
-
1130
  if output_type == "latent":
1131
- return AnimateDiffControlNetPipelineOutput(frames=latents)
1132
-
1133
  # Post-processing
1134
  video_tensor = self.decode_latents(latents)
1135
 
@@ -1144,4 +1568,4 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
1144
  if not return_dict:
1145
  return (video,)
1146
 
1147
- return AnimateDiffControlNetPipelineOutput(frames=video)
 
14
 
15
  import inspect
16
  from dataclasses import dataclass
17
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
18
 
19
  import numpy as np
20
  import torch
 
 
21
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
22
 
23
+ # Updated to use absolute paths
24
  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
25
  from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
26
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel, ControlNetModel
27
  from diffusers.models.lora import adjust_lora_scale_text_encoder
28
  from diffusers.models.unet_motion_model import MotionAdapter
 
 
29
  from diffusers.schedulers import (
30
  DDIMScheduler,
31
  DPMSolverMultistepScheduler,
 
34
  LMSDiscreteScheduler,
35
  PNDMScheduler,
36
  )
37
+ from diffusers.utils import (
38
+ USE_PEFT_BACKEND,
39
+ BaseOutput,
40
+ logging,
41
+ scale_lora_layers,
42
+ unscale_lora_layers,
43
+ )
44
  from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
45
 
46
+ # Added imports based on the working paths
47
+ from diffusers.models import ControlNetModel
48
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
49
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
50
+ from diffusers.utils import deprecate
51
+
52
+ import torchvision
53
+ import PIL
54
+ import PIL.Image
55
+ import math
56
+ import time
57
+
58
 
59
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
60
 
 
62
  Examples:
63
  ```py
64
  >>> import torch
65
+ >>> from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
66
  >>> from diffusers.utils import export_to_gif
67
+ >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
68
+ >>> pipe = AnimateDiffPipeline.from_pretrained("frankjoshua/toonyou_beta6", motion_adapter=adapter)
69
+ >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False)
70
+ >>> output = pipe(prompt="A corgi walking in the park")
71
+ >>> frames = output.frames[0]
72
+ >>> export_to_gif(frames, "animation.gif")
73
  ```
74
  """
75
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
76
+ def retrieve_latents(
77
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
78
+ ):
79
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
80
+ return encoder_output.latent_dist.sample(generator)
81
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
82
+ return encoder_output.latent_dist.mode()
83
+ elif hasattr(encoder_output, "latents"):
84
+ return encoder_output.latents
85
+ else:
86
+ raise AttributeError("Could not access latents of provided encoder_output")
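A minimal usage sketch of `retrieve_latents` (editorial illustration, not part of this commit; the VAE checkpoint and dummy batch are assumptions):
```py
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")   # assumed checkpoint
pixel_values = torch.rand(1, 3, 512, 512) * 2 - 1                  # dummy image batch in [-1, 1]

posterior = vae.encode(pixel_values)                               # exposes `latent_dist`, so "sample" mode applies
latents = retrieve_latents(posterior, generator=torch.Generator().manual_seed(0))
latents = latents * vae.config.scaling_factor                      # usual scaling before denoising
print(latents.shape)                                               # torch.Size([1, 4, 64, 64])
```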
87
+
88
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
89
+ def retrieve_timesteps(
90
+ scheduler,
91
+ num_inference_steps: Optional[int] = None,
92
+ device: Optional[Union[str, torch.device]] = None,
93
+ timesteps: Optional[List[int]] = None,
94
+ **kwargs,
95
+ ):
96
+ """
97
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
98
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
99
 
100
+ Args:
101
+ scheduler (`SchedulerMixin`):
102
+ The scheduler to get timesteps from.
103
+ num_inference_steps (`int`):
104
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
105
+ `timesteps` must be `None`.
106
+ device (`str` or `torch.device`, *optional*):
107
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
108
+ timesteps (`List[int]`, *optional*):
109
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
110
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
111
+ must be `None`.
112
+
113
+ Returns:
114
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
115
+ second element is the number of inference steps.
116
+ """
117
+ if timesteps is not None:
118
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
119
+ # if not accepts_timesteps:
120
+ # raise ValueError(
121
+ # f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
122
+ # f" timestep schedules. Please check whether you are using the correct scheduler."
123
+ # )
124
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
125
+ timesteps = scheduler.timesteps
126
+ num_inference_steps = len(timesteps)
127
+ else:
128
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
129
+ timesteps = scheduler.timesteps
130
+ return timesteps, num_inference_steps
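As a usage note (editorial, not part of the commit), `retrieve_timesteps` takes either `num_inference_steps` or an explicit `timesteps` list, never both; a minimal sketch:
```py
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(beta_schedule="linear", clip_sample=False, steps_offset=1)
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=25, device="cpu")
print(num_inference_steps, timesteps[:3])

# Passing custom `timesteps` only works for schedulers whose `set_timesteps` accepts a
# `timesteps` argument; that is what the (here commented-out) check used to guard against.
```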
131
 
132
  def tensor2vid(video: torch.Tensor, processor, output_type="np"):
133
  # Based on:
 
145
 
146
 
147
  @dataclass
148
+ class AnimateDiffPipelineOutput(BaseOutput):
149
  frames: Union[torch.Tensor, np.ndarray]
150
 
151
 
152
+ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
153
  r"""
154
  Pipeline for text-to-video generation.
 
155
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
156
  implemented for all pipelines (downloading, saving, running on a particular device, etc.).
 
157
  The pipeline also inherits the following loading methods:
158
  - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
159
  - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
160
  - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
161
  - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
162
  Args:
163
  vae ([`AutoencoderKL`]):
164
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
 
175
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
176
  """
177
 
178
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
179
+ _optional_components = ["feature_extractor", "image_encoder", "controlnet"]
 
180
 
181
  def __init__(
182
  self,
 
185
  tokenizer: CLIPTokenizer,
186
  unet: UNet2DConditionModel,
187
  motion_adapter: MotionAdapter,
 
188
  scheduler: Union[
189
  DDIMScheduler,
190
  PNDMScheduler,
 
193
  EulerAncestralDiscreteScheduler,
194
  DPMSolverMultistepScheduler,
195
  ],
196
+ controlnet: Optional[Union[ControlNetModel, MultiControlNetModel]] = None,
197
  feature_extractor: Optional[CLIPImageProcessor] = None,
198
  image_encoder: Optional[CLIPVisionModelWithProjection] = None,
199
  ):
200
  super().__init__()
201
  unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
202
+ if hasattr(self, "controlnet"): print("has controlnet")
203
+ if controlnet is None:
204
+ if hasattr(self, "controlnet"): delattr(self, "controlnet")
205
+
206
+ # print all the attributes
207
+ print("Attributes:")
208
+ for attr in dir(self):
209
+ print(attr)
210
+
211
+ print("controlnet still exists:", hasattr(self, "controlnet"))
212
+
213
  self.register_modules(
214
  vae=vae,
215
  text_encoder=text_encoder,
 
227
  vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
228
  )
229
 
230
+ def load_motion_adapter(self, motion_adapter):
231
+ self.register_modules(motion_adapter=motion_adapter)
232
+
233
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
234
  def encode_prompt(
235
  self,
 
245
  ):
246
  r"""
247
  Encodes the prompt into text encoder hidden states.
 
248
  Args:
249
  prompt (`str` or `List[str]`, *optional*):
250
  prompt to be encoded
 
412
  return prompt_embeds, negative_prompt_embeds
413
 
414
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
415
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
416
  dtype = next(self.image_encoder.parameters()).dtype
417
 
418
  if not isinstance(image, torch.Tensor):
419
  image = self.feature_extractor(image, return_tensors="pt").pixel_values
420
 
421
  image = image.to(device=device, dtype=dtype)
422
+ if output_hidden_states:
423
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
424
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
425
+ uncond_image_enc_hidden_states = self.image_encoder(
426
+ torch.zeros_like(image), output_hidden_states=True
427
+ ).hidden_states[-2]
428
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
429
+ num_images_per_prompt, dim=0
430
+ )
431
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
432
+ else:
433
+ image_embeds = self.image_encoder(image).image_embeds
434
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
435
+ uncond_image_embeds = torch.zeros_like(image_embeds)
436
 
437
+ return image_embeds, uncond_image_embeds
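For context (editorial note): the `output_hidden_states` branch mirrors upstream diffusers behaviour, where IP-Adapter "plus" variants consume penultimate hidden states while the basic adapter consumes the pooled `image_embeds`. A standalone sketch of the two shapes, using an assumed CLIP vision checkpoint:
```py
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")  # assumption
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

pixels = processor(Image.new("RGB", (256, 256)), return_tensors="pt").pixel_values
pooled = encoder(pixels).image_embeds                                  # [1, projection_dim]
tokens = encoder(pixels, output_hidden_states=True).hidden_states[-2]  # [1, num_patches + 1, hidden_size]
print(pooled.shape, tokens.shape)
```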
 
438
 
439
  # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
440
  def decode_latents(self, latents):
 
496
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
497
  def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
498
  r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
 
499
  The suffixes after the scaling factors represent the stages where they are being applied.
 
500
  Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
501
  that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
 
502
  Args:
503
  s1 (`float`):
504
  Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
 
547
  prompt_embeds=None,
548
  negative_prompt_embeds=None,
549
  callback_on_step_end_tensor_inputs=None,
 
 
 
 
550
  ):
551
  if height % 8 != 0 or width % 8 != 0:
552
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
589
  f" {negative_prompt_embeds.shape}."
590
  )
591
 
592
+ # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
593
+ def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None):
594
+ shape = (
595
+ batch_size,
596
+ num_channels_latents,
597
+ num_frames,
598
+ height // self.vae_scale_factor,
599
+ width // self.vae_scale_factor,
 
 
 
 
600
  )
601
+ if isinstance(generator, list) and len(generator) != batch_size:
602
+ raise ValueError(
603
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
604
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
605
+ )
606
 
607
+ if latents is None:
608
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
609
  else:
610
+ latents = latents.to(device)
611
 
612
+ # scale the initial noise by the standard deviation required by the scheduler
613
+ latents = latents * self.scheduler.init_noise_sigma
614
+ return latents
615
+
616
+ def prepare_latents_same_start(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, context_size=16, overlap=4, strength=0.5):
617
+ shape = (
618
+ batch_size,
619
+ num_channels_latents,
620
+ num_frames,
621
+ height // self.vae_scale_factor,
622
+ width // self.vae_scale_factor,
623
+ )
624
+ if isinstance(generator, list) and len(generator) != batch_size:
625
+ raise ValueError(
626
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
627
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
628
+ )
629
+
630
+ if latents is None:
631
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
 
 
632
  else:
633
+ latents = latents.to(device)
634
+
635
+ # make every (context_size-overlap) frames have the same noise
636
+ loop_size = context_size - overlap
637
+ loop_count = num_frames // loop_size
638
+ for i in range(loop_count):
639
+ # repeat the first frames noise for i*loop_size frame
640
+ # lerp the first frames noise
641
+ latents[:, :, i*loop_size:(i*loop_size)+overlap, :, :] = torch.lerp(latents[:, :, i*loop_size:(i*loop_size)+overlap, :, :], latents[:, :, 0:overlap, :, :], strength)
642
+
643
+ # scale the initial noise by the standard deviation required by the scheduler
644
+ latents = latents * self.scheduler.init_noise_sigma
645
+ return latents
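An editorial sketch (not part of the commit) of the windowing above: with `context_size=16` and `overlap=4`, frames `[12:16]`, `[24:28]`, ... are lerped toward the noise of frames `[0:4]`, so successive context windows start from similar noise:
```py
import torch

context_size, overlap, strength, num_frames = 16, 4, 0.5, 32
noise = torch.randn(1, 4, num_frames, 8, 8)            # toy latent video: [B, C, F, H, W]
loop_size = context_size - overlap
for i in range(num_frames // loop_size):
    window = slice(i * loop_size, i * loop_size + overlap)
    noise[:, :, window] = torch.lerp(noise[:, :, window], noise[:, :, 0:overlap], strength)
# frames 12..15 and 24..27 now sit halfway between their own noise and frames 0..3
```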
646
+
647
+ def prepare_latents_consistent(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None,smooth_weight=0.5,smooth_steps=3):
648
+ shape = (
649
+ batch_size,
650
+ num_channels_latents,
651
+ num_frames,
652
+ height // self.vae_scale_factor,
653
+ width // self.vae_scale_factor,
654
+ )
655
+ if isinstance(generator, list) and len(generator) != batch_size:
656
+ raise ValueError(
657
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
658
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
659
+ )
660
+
661
+ if latents is None:
662
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
663
 
664
+ # blend each frame with the surrounding N frames making sure to wrap around at the end
665
+ for i in range(num_frames):
666
+ blended_latent = torch.zeros_like(latents[:, :, i])
667
+ for s in range(-smooth_steps, smooth_steps + 1):
668
+ if s == 0:
669
+ continue
670
+ frame_index = (i + s) % num_frames
671
+ weight = (smooth_steps - abs(s)) / smooth_steps
672
+ blended_latent += latents[:, :, frame_index] * weight
673
+ latents[:, :, i] = blended_latent / (2 * smooth_steps)
674
+
675
+ latents = torch.lerp(randn_tensor(shape, generator=generator, device=device, dtype=dtype),latents, smooth_weight)
676
+ else:
677
+ latents = latents.to(device)
678
 
679
+ # scale the initial noise by the standard deviation required by the scheduler
680
+ latents = latents * self.scheduler.init_noise_sigma
681
+ return latents
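Editorial sketch of the circular neighbourhood used above (illustration only): each frame is blended with its `smooth_steps` neighbours on both sides, with wrap-around indexing and weights that fall off linearly with distance:
```py
smooth_steps, num_frames, i = 3, 8, 0
neighbors = [((i + s) % num_frames, (smooth_steps - abs(s)) / smooth_steps)
             for s in range(-smooth_steps, smooth_steps + 1) if s != 0]
print(neighbors)  # frame 0 blends with frames 5, 6, 7, 1, 2, 3 at weights 0, 1/3, 2/3, 2/3, 1/3, 0
```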
682
 
683
+ def prepare_motion_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator,
684
+ latents=None, x_velocity=0, y_velocity=0, scale_velocity=0):
685
+ shape = (
686
+ batch_size,
687
+ num_channels_latents,
688
+ num_frames,
689
+ height // self.vae_scale_factor,
690
+ width // self.vae_scale_factor,
691
+ )
692
+ if isinstance(generator, list) and len(generator) != batch_size:
693
  raise ValueError(
694
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
695
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
696
  )
697
 
698
+ if latents is None:
699
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
700
+ else:
701
+ latents = latents.to(device)
 
702
 
703
+ # scale the initial noise by the standard deviation required by the scheduler
704
+ latents = latents * self.scheduler.init_noise_sigma
705
+
706
+ for frame in range(num_frames):
707
+ x_offset = int(frame * x_velocity) # Convert to int
708
+ y_offset = int(frame * y_velocity) # Convert to int
709
+ scale_factor = 1 + frame * scale_velocity
710
+
711
+ # Apply offsets
712
+ latents[:, :, frame] = torch.roll(latents[:, :, frame], shifts=(x_offset,), dims=3) # x direction
713
+ latents[:, :, frame] = torch.roll(latents[:, :, frame], shifts=(y_offset,), dims=2) # y direction
714
+
715
+ # Apply scaling - This is a simple approach and might not be ideal for all applications
716
+ if scale_factor != 1:
717
+ scaled_size = (
718
+ int(latents.shape[3] * scale_factor),
719
+ int(latents.shape[4] * scale_factor)
720
  )
721
+ latents[:, :, frame] = torch.nn.functional.interpolate(
722
+ latents[:, :, frame].unsqueeze(0), size=scaled_size, mode='bilinear', align_corners=False
723
+ ).squeeze(0)
724
+
725
+ return latents
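For intuition (editorial sketch, not part of the commit): `torch.roll` on a spatial dimension wraps content around the edge, so a constant per-frame `x_velocity` produces a looping pan of the initial noise:
```py
import torch

frame = torch.arange(16.0).reshape(1, 1, 4, 4)      # toy 4x4 single-channel "latent"
panned = torch.roll(frame, shifts=(1,), dims=(3,))  # shift one latent pixel along x
print(frame[0, 0, 0], panned[0, 0, 0])              # tensor([0., 1., 2., 3.]) -> tensor([3., 0., 1., 2.])
```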
726
+
727
+ def generate_correlated_noise(self, latents, init_noise_correlation):
728
+ cloned_latents = latents.clone()
729
+ p = init_noise_correlation
730
+ flattened_latents = torch.flatten(cloned_latents)
731
+ noise = torch.randn_like(flattened_latents)
732
+ correlated_noise = flattened_latents * p + math.sqrt(1 - p**2) * noise
733
+
734
+ return correlated_noise.reshape(cloned_latents.shape)
735
+
736
+ def generate_correlated_latents(self, latents, init_noise_correlation):
737
+ cloned_latents = latents.clone()
738
+ for i in range(1, cloned_latents.shape[2]):
739
+ p = init_noise_correlation
740
+ flattened_latents = torch.flatten(cloned_latents[:, :, i])
741
+ prev_flattened_latents = torch.flatten(cloned_latents[:, :, i - 1])
742
+ correlated_latents = prev_flattened_latents * p / math.sqrt(1 + p**2) + flattened_latents * math.sqrt(1 / (1 + p**2))
743
+ cloned_latents[:, :, i] = correlated_latents.reshape(cloned_latents[:, :, i].shape)
744
+
745
+ return cloned_latents
746
+
747
+ def generate_correlated_latents_legacy(self, latents, init_noise_correlation):
748
+ cloned_latents = latents.clone()
749
+ for i in range(1, cloned_latents.shape[2]):
750
+ p = init_noise_correlation
751
+ flattened_latents = torch.flatten(cloned_latents[:, :, i])
752
+ prev_flattened_latents = torch.flatten(cloned_latents[:, :, i - 1])
753
+ correlated_latents = (
754
+ prev_flattened_latents * p
755
+ +
756
+ flattened_latents * math.sqrt(1 - p**2)
757
+ )
758
+ cloned_latents[:, :, i] = correlated_latents.reshape(
759
+ cloned_latents[:, :, i].shape
760
  )
761
 
762
+ return cloned_latents
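Both correlation formulas above are variance-preserving mixes of independent Gaussians, which is why chaining them across frames does not change the latent scale; a quick editorial check:
```py
import math
import torch

p = 0.7
a, b = torch.randn(1_000_000), torch.randn(1_000_000)
mixed = a * p / math.sqrt(1 + p**2) + b * math.sqrt(1 / (1 + p**2))  # generate_correlated_latents
mixed_legacy = a * p + b * math.sqrt(1 - p**2)                       # generate_correlated_noise / *_legacy
print(mixed.std().item(), mixed_legacy.std().item())                 # both approximately 1.0
```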
763
+
764
+ def generate_mixed_noise(self, noise, init_noise_correlation):
765
+ shared_noise = torch.randn_like(noise[0, :, 0])
766
+ for b in range(noise.shape[0]):
767
+ for f in range(noise.shape[2]):
768
+ p = init_noise_correlation
769
+ flattened_latents = torch.flatten(noise[b, :, f])
770
+ shared_latents = torch.flatten(shared_noise)
771
+ correlated_latents = (
772
+ shared_latents * math.sqrt(p**2/(1+p**2)) +
773
+ flattened_latents * math.sqrt(1/(1+p**2))
774
+ )
775
+ noise[b, :, f] = correlated_latents.reshape(noise[b, :, f].shape)
776
 
777
+ return noise
778
+
779
+ def prepare_correlated_latents(
780
+ self,
781
+ init_image,
782
+ init_image_strength,
783
+ init_noise_correlation,
784
+ batch_size,
785
+ num_channels_latents,
786
+ video_length,
787
+ height,
788
+ width,
789
+ dtype,
790
+ device,
791
+ generator,
792
+ latents=None,
793
+ ):
794
+ shape = (
795
+ batch_size,
796
+ num_channels_latents,
797
+ video_length,
798
+ height // self.vae_scale_factor,
799
+ width // self.vae_scale_factor,
800
+ )
801
+
802
+ if init_image is not None:
803
+ start_image = (torchvision.transforms.functional.pil_to_tensor(init_image) / 255)[:3, :, :].to(device).to(dtype).unsqueeze(0)
804
+ start_image = (
805
+ self.vae.encode(start_image.mul(2).sub(1))
806
+ .latent_dist.sample()
807
+ .view(1, 4, height // 8, width // 8)
808
+ * 0.18215
809
+ )
810
+ init_latents = start_image.unsqueeze(2).repeat(1, 1, video_length, 1, 1)
811
+ else:
812
+ init_latents = None
813
 
814
+ if isinstance(generator, list) and len(generator) != batch_size:
815
  raise ValueError(
816
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
817
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
818
  )
819
+ if latents is None:
820
+ rand_device = "cpu" if device.type == "mps" else device
821
+ if isinstance(generator, list):
822
+ shape = shape
823
+ # shape = (1,) + shape[1:]
824
+ # ignore init latents for batch model
825
+ latents = [torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)for i in range(batch_size)]
826
+ latents = torch.cat(latents, dim=0).to(device)
827
+ else:
828
+ if init_latents is not None:
829
+ offset = int(
830
+ init_image_strength * (len(self.scheduler.timesteps) - 1)
831
+ )
832
+ noise = torch.randn_like(init_latents)
833
+ noise = self.generate_correlated_latents(noise, init_noise_correlation)
834
+
835
+ # Eric - some black magic here
+ # We should only be adding the noise at timesteps[offset], but adding it at
+ # timesteps[offset - 1] (or offset - 2) gives noticeably more motion. However, this
+ # breaks down when few timesteps are used, so interpolate between the two instead.
839
+ timesteps = self.scheduler.timesteps
840
+ average_timestep = None
841
+ if offset == 0:
842
+ average_timestep = timesteps[0]
843
+ elif offset == 1:
844
+ average_timestep = (
845
+ timesteps[offset - 1] * (1 - init_image_strength)
846
+ + timesteps[offset] * init_image_strength
847
+ )
848
+ else:
849
+ average_timestep = timesteps[offset - 1]
850
+
851
+ latents = self.scheduler.add_noise(
852
+ init_latents, noise, average_timestep.long()
853
+ )
854
+
855
+ latents = self.scheduler.add_noise(
856
+ latents, torch.randn_like(init_latents), timesteps[-2]
857
+ )
858
+ else:
859
+ latents = torch.randn(
860
+ shape, generator=generator, device=rand_device, dtype=dtype
861
+ ).to(device)
862
+ latents = self.generate_correlated_latents(
863
+ latents, init_noise_correlation
864
+ )
865
+ else:
866
+ if latents.shape != shape:
867
+ raise ValueError(
868
+ f"Unexpected latents shape, got {latents.shape}, expected {shape}"
869
+ )
870
+ latents = latents.to(device)
871
 
872
+ # scale the initial noise by the standard deviation required by the scheduler
873
+ if init_latents is None:
874
+ latents = latents * self.scheduler.init_noise_sigma
875
+ # elif self.unet.trained_initial_frames and init_latents is not None:
876
+ # # we only want to use this as the first frame
877
+ # init_latents[:, :, 1:] = torch.zeros_like(init_latents[:, :, 1:])
878
+
879
+ latents = latents.to(device)
880
+ return latents, init_latents
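+
+ # In short: prepare_correlated_latents encodes `init_image` once, repeats the latent
+ # across all frames, correlates fresh noise frame-to-frame with `init_noise_correlation`,
+ # and then noises the result to a timestep interpolated from `init_image_strength`.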
881
+
882
+ def prepare_video_latents(
883
+ self,
884
+ video,
885
+ height,
886
+ width,
887
+ num_channels_latents,
888
+ batch_size,
889
+ timestep,
890
+ dtype,
891
+ device,
892
+ generator,
893
+ latents=None,
894
  ):
895
+ # `video` must be a list of lists of images: the outer list indexes the input videos
+ # and each inner list holds the frames of one video.
+ if not isinstance(video[0], list):
+ video = [video]
900
+ if latents is None:
901
+ video = torch.cat(
902
+ [self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0
903
+ )
904
+ video = video.to(device=device, dtype=dtype)
905
+ num_frames = video.shape[1]
906
+ else:
907
+ num_frames = latents.shape[2]
908
+
909
  shape = (
910
  batch_size,
911
  num_channels_latents,
 
913
  height // self.vae_scale_factor,
914
  width // self.vae_scale_factor,
915
  )
916
+
917
  if isinstance(generator, list) and len(generator) != batch_size:
918
  raise ValueError(
919
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
 
921
  )
922
 
923
  if latents is None:
924
+ # make sure the VAE is in float32 mode, as it overflows in float16
925
+ if self.vae.config.force_upcast:
926
+ video = video.float()
927
+ self.vae.to(dtype=torch.float32)
928
+
929
+ if isinstance(generator, list):
930
+ if len(generator) != batch_size:
931
+ raise ValueError(
932
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
933
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
934
+ )
935
+
936
+ init_latents = [
937
+ retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
938
+ for i in range(batch_size)
939
+ ]
940
+ else:
941
+ init_latents = [
942
+ retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
943
+ ]
944
 
945
+ init_latents = torch.cat(init_latents, dim=0)
946
+
947
+ # restore vae to original dtype
948
+ if self.vae.config.force_upcast:
949
+ self.vae.to(dtype)
950
+
951
+ init_latents = init_latents.to(dtype)
952
+ init_latents = self.vae.config.scaling_factor * init_latents
953
 
954
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
955
+ # expand init_latents for batch_size
956
+ error_message = (
957
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
958
+ " images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
959
+ )
960
+ raise ValueError(error_message)
961
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
962
+ raise ValueError(
963
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
964
+ )
965
+ else:
966
+ init_latents = torch.cat([init_latents], dim=0)
967
+
968
+ noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
969
+ latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
970
+ else:
971
+ if shape != latents.shape:
972
+ # [B, C, F, H, W]
973
+ raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
974
+ latents = latents.to(device, dtype=dtype)
975
+
976
+ return latents
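+
+ # The returned latents are laid out as [batch, channels, frames, height, width], which is
+ # the layout the rest of this pipeline indexes frames against (latents[:, :, frame_ids]).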
977
+
978
+
979
- # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ # Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_control_frames(
981
  self,
982
  image,
983
  width,
 
990
  guess_mode=False,
991
  ):
992
  image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
993
+ # image_batch_size = image.shape[0]
994
+ image_batch_size = len(image)
995
 
996
+ # if image_batch_size == 1:
997
+ # repeat_by = batch_size
998
+ # else:
999
+ # # image batch size is the same as prompt batch size
1000
+ # repeat_by = num_images_per_prompt
1001
 
1002
+ # image = image.repeat_interleave(repeat_by, dim=0)
1003
 
1004
  image = image.to(device=device, dtype=dtype)
1005
 
1006
+ # if do_classifier_free_guidance and not guess_mode:
1007
+ # image = torch.cat([image] * 2)
1008
 
1009
  return image
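+
+ # Unlike the upstream prepare_image, no batch repetition or classifier-free-guidance
+ # duplication happens here; both are applied later, per context window, inside the
+ # denoising loop.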
1010
+
1011
  @torch.no_grad()
1012
  def __call__(
1013
  self,
1014
  prompt: Union[str, List[str]] = None,
1015
  num_frames: Optional[int] = 16,
1016
+ context_size: int = 16,
+ overlap: int = 2,
+ step: int = 1,
1019
  height: Optional[int] = None,
1020
  width: Optional[int] = None,
1021
  num_inference_steps: int = 50,
 
1028
  prompt_embeds: Optional[torch.FloatTensor] = None,
1029
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1030
  ip_adapter_image: Optional[PipelineImageInput] = None,
 
1031
  output_type: Optional[str] = "pil",
1032
+ output_path: Optional[str] = None,
1033
  return_dict: bool = True,
1034
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1035
+ callback_steps: Optional[int] = 1,
1036
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1037
+ clip_skip: Optional[int] = None,
1038
+ x_velocity: Optional[float] = 0,
1039
+ y_velocity: Optional[float] = 0,
1040
+ scale_velocity: Optional[float] = 0,
1041
+ init_image: Optional[PipelineImageInput] = None,
1042
+ init_image_strength: Optional[float] = 1.0,
1043
+ init_noise_correlation: Optional[float] = 0.0,
1044
+ latent_mode: Optional[str] = "normal",
1045
+ smooth_weight: Optional[float] = 0.5,
1046
+ smooth_steps: Optional[int] = 3,
1047
+ initial_context_scale: Optional[float] = 1.0,
1048
+ conditioning_frames: Optional[List[PipelineImageInput]] = None,
1049
  controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
 
1050
  control_guidance_start: Union[float, List[float]] = 0.0,
1051
  control_guidance_end: Union[float, List[float]] = 1.0,
1052
+ guess_mode: bool = False,
 
 
 
1053
  ):
1054
  r"""
1055
  The call function to the pipeline for generation.
 
1056
  Args:
1057
  prompt (`str` or `List[str]`, *optional*):
1058
  The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
 
1089
  negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1090
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1091
  not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1092
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
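+ context_size (`int`, *optional*, defaults to 16):
+ Number of frames denoised together in each sliding context window.
+ overlap (`int`, *optional*, defaults to 2):
+ Number of frames shared between neighboring context windows.
+ step (`int`, *optional*, defaults to 1):
+ Number of frames the context windows are shifted by at every timestep.
+ latent_mode (`str`, *optional*, defaults to `"normal"`):
+ Strategy used to initialize the latents; one of `"normal"`, `"same_start"`, `"motion"`,
+ `"correlated"`, `"consistent"` or `"video"`.
+ conditioning_frames (`List[PipelineImageInput]`, *optional*):
+ Per-frame ControlNet conditioning images; only used when a ControlNet is loaded.
+ output_path (`str`, *optional*):
+ If set, decoded frames are written to this path pattern (a run of `#` characters is
+ replaced by the zero-padded frame index) and the pattern is returned instead of the frames.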
 
 
 
 
 
1093
  output_type (`str`, *optional*, defaults to `"pil"`):
1094
  The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
1095
  `np.array`.
1096
  return_dict (`bool`, *optional*, defaults to `True`):
1097
  Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
1098
  of a plain tuple.
1099
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called
+ with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
1105
  cross_attention_kwargs (`dict`, *optional*):
1106
  A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1107
  [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
 
 
 
 
 
 
 
 
 
 
 
1108
  clip_skip (`int`, *optional*):
1109
  Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1110
  the output of the pre-final layer will be used for computing the prompt embeddings.
 
 
 
 
 
 
 
 
 
 
1111
  Examples:
 
1112
  Returns:
1113
  [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
1114
  If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
1115
  returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
1116
  """
1117
+
1118
+ if self.controlnet is not None:
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ # align format for control guidance
+ # keep the scalar end value around; it is used later to stop applying ControlNet after this fraction of the steps
+ control_end = control_guidance_end
1123
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1124
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1125
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1126
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1127
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1128
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1129
+ control_guidance_start, control_guidance_end = (
1130
+ mult * [control_guidance_start],
1131
+ mult * [control_guidance_end],
1132
+ )
1133
+
1134
+
1135
  # 0. Default height and width to unet
1136
  height = height or self.unet.config.sample_size * self.vae_scale_factor
1137
  width = width or self.unet.config.sample_size * self.vae_scale_factor
 
1140
 
1141
  # 1. Check inputs. Raise error if not correct
1142
  self.check_inputs(
1143
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
 
 
 
 
 
 
 
 
 
 
 
1144
  )
1145
 
 
 
 
 
1146
  # 2. Define call parameters
1147
  if prompt is not None and isinstance(prompt, str):
1148
  batch_size = 1
 
1152
  batch_size = prompt_embeds.shape[0]
1153
 
1154
  device = self._execution_device
1155
+
1156
+ if self.controlnet is not None:
1157
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1158
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1159
+
1160
+ global_pool_conditions = (
1161
+ controlnet.config.global_pool_conditions
1162
+ if isinstance(controlnet, ControlNetModel)
1163
+ else controlnet.nets[0].config.global_pool_conditions
1164
+ )
1165
+ guess_mode = guess_mode or global_pool_conditions
1166
+
1167
+
1168
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
1169
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1170
+ # corresponds to doing no classifier free guidance.
1171
+ do_classifier_free_guidance = guidance_scale > 1.0
1172
 
1173
  # 3. Encode input prompt
1174
  text_encoder_lora_scale = (
 
1178
  prompt,
1179
  device,
1180
  num_videos_per_prompt,
1181
+ do_classifier_free_guidance,
1182
  negative_prompt,
1183
  prompt_embeds=prompt_embeds,
1184
  negative_prompt_embeds=negative_prompt_embeds,
1185
  lora_scale=text_encoder_lora_scale,
1186
+ clip_skip=clip_skip,
1187
  )
1188
  # For classifier free guidance, we need to do two forward passes.
1189
  # Here we concatenate the unconditional and text embeddings into a single batch
1190
  # to avoid doing two forward passes
1191
+ if do_classifier_free_guidance:
1192
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1193
 
1194
  if ip_adapter_image is not None:
1195
+ output_hidden_state = not isinstance(self.unet.encoder_hid_proj, ImageProjection)
1196
+ image_embeds, negative_image_embeds = self.encode_image(
1197
+ ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
1198
+ )
1199
+ if do_classifier_free_guidance:
1200
  image_embeds = torch.cat([negative_image_embeds, image_embeds])
1201
 
1202
+ if self.controlnet is not None:
1203
+ if isinstance(controlnet, ControlNetModel):
1204
+ # conditioning_frames = self.prepare_image(
1205
+ # image=conditioning_frames,
1206
+ # width=width,
1207
+ # height=height,
1208
+ # batch_size=batch_size * num_videos_per_prompt * num_frames,
1209
+ # num_images_per_prompt=num_videos_per_prompt,
1210
+ # device=device,
1211
+ # dtype=controlnet.dtype,
1212
+ # do_classifier_free_guidance=self.do_classifier_free_guidance,
1213
+ # guess_mode=guess_mode,
1214
+ # )
1215
+ conditioning_frames = self.prepare_control_frames(
1216
+ image=conditioning_frames,
 
 
1217
  width=width,
1218
  height=height,
1219
  batch_size=batch_size * num_videos_per_prompt * num_frames,
1220
  num_images_per_prompt=num_videos_per_prompt,
1221
  device=device,
1222
  dtype=controlnet.dtype,
1223
+ do_classifier_free_guidance=do_classifier_free_guidance,
1224
  guess_mode=guess_mode,
1225
  )
1226
+
1227
+ elif isinstance(controlnet, MultiControlNetModel):
1228
+ cond_prepared_frames = []
1229
+ for frame_ in conditioning_frames:
1230
+ # prepared_frame = self.prepare_image(
1231
+ # image=frame_,
1232
+ # width=width,
1233
+ # height=height,
1234
+ # batch_size=batch_size * num_videos_per_prompt * num_frames,
1235
+ # num_images_per_prompt=num_videos_per_prompt,
1236
+ # device=device,
1237
+ # dtype=controlnet.dtype,
1238
+ # do_classifier_free_guidance=self.do_classifier_free_guidance,
1239
+ # guess_mode=guess_mode,
1240
+ # )
1241
+
1242
+ prepared_frame = self.prepare_control_frames(
1243
+ image=frame_,
1244
+ width=width,
1245
+ height=height,
1246
+ batch_size=batch_size * num_videos_per_prompt * num_frames,
1247
+ num_images_per_prompt=num_videos_per_prompt,
1248
+ device=device,
1249
+ dtype=controlnet.dtype,
1250
+ do_classifier_free_guidance=do_classifier_free_guidance,
1251
+ guess_mode=guess_mode,
1252
+ )
1253
+
1254
+ cond_prepared_frames.append(prepared_frame)
1255
+
1256
+ conditioning_frames = cond_prepared_frames
1257
+ else:
+ raise ValueError(f"Unsupported ControlNet type: {type(controlnet)}")
1259
 
1260
  # 4. Prepare timesteps
1261
  self.scheduler.set_timesteps(num_inference_steps, device=device)
1262
  timesteps = self.scheduler.timesteps
1263
+
1264
 
1265
+ # round num_frames down to the nearest multiple of (context_size - overlap)
+ num_frames = (num_frames // (context_size - overlap)) * (context_size - overlap)
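+ # e.g. with context_size=16 and overlap=2, a requested num_frames of 50 is reduced to 42 (3 * 14)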
1267
+
1268
  # 5. Prepare latent variables
1269
  num_channels_latents = self.unet.config.in_channels
1270
+ if latent_mode == "normal":
1271
+ latents = self.prepare_latents(
1272
+ batch_size * num_videos_per_prompt,
1273
+ num_channels_latents,
1274
+ num_frames,
1275
+ height,
1276
+ width,
1277
+ prompt_embeds.dtype,
1278
+ device,
1279
+ generator,
1280
+ latents,
1281
+ )
1282
+ elif latent_mode == "same_start":
1283
+ latents = self.prepare_latents_same_start(
1284
+ batch_size * num_videos_per_prompt,
1285
+ num_channels_latents,
1286
+ num_frames,
1287
+ height,
1288
+ width,
1289
+ prompt_embeds.dtype,
1290
+ device,
1291
+ generator,
1292
+ latents,
1293
+ context_size=context_size,
1294
+ overlap=overlap,
1295
+ strength=init_image_strength,
1296
+ )
1297
+ elif latent_mode == "motion":
1298
+ latents = self.prepare_motion_latents(
1299
+ batch_size * num_videos_per_prompt,
1300
+ num_channels_latents,
1301
+ num_frames,
1302
+ height,
1303
+ width,
1304
+ prompt_embeds.dtype,
1305
+ device,
1306
+ generator,
1307
+ latents,
1308
+ x_velocity=x_velocity,
1309
+ y_velocity=y_velocity,
1310
+ scale_velocity=scale_velocity,
1311
+ )
1312
+ elif latent_mode == "correlated":
1313
+ latents, init_latents = self.prepare_correlated_latents(
1314
+ init_image,
1315
+ init_image_strength,
1316
+ init_noise_correlation,
1317
+ batch_size,
1318
+ num_channels_latents,
1319
+ num_frames,
1320
+ height,
1321
+ width,
1322
+ prompt_embeds.dtype,
1323
+ device,
1324
+ generator,
1325
+ )
1326
+ elif latent_mode == "consistent":
1327
+ latents = self.prepare_latents_consistent(
1328
+ batch_size * num_videos_per_prompt,
1329
+ num_channels_latents,
1330
+ num_frames,
1331
+ height,
1332
+ width,
1333
+ prompt_embeds.dtype,
1334
+ device,
1335
+ generator,
1336
+ latents,
1337
+ smooth_weight,
1338
+ smooth_steps,
1339
+ )
1340
+ elif latent_mode == "video":
1341
+ # 4. Prepare timesteps
1342
+ # timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1343
+ # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, init_image_strength, device)
1344
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
1345
+ self._num_timesteps = len(timesteps)
1346
+ num_channels_latents = self.unet.config.in_channels
1347
+ latents = self.prepare_video_latents(
1348
+ video=init_image,
1349
+ height=height,
1350
+ width=width,
1351
+ num_channels_latents=num_channels_latents,
1352
+ batch_size=batch_size * num_videos_per_prompt,
1353
+ timestep=latent_timestep,
1354
+ dtype=prompt_embeds.dtype,
1355
+ device=device,
1356
+ generator=generator,
1357
+ latents=latents,
1358
+ )
1359
+
1360
 
1361
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1362
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1363
+
1364
+ # 7 Add image embeds for IP-Adapter
1365
  added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
1366
+
1367
  # 7.1 Create tensor stating which controlnets to keep
1368
+ if self.controlnet is not None:
1369
+ controlnet_keep = []
1370
+ for i in range(len(timesteps)):
1371
+ keeps = [
1372
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1373
+ for s, e in zip(control_guidance_start, control_guidance_end)
1374
+ ]
1375
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
1376
+
1377
+ # divide the initial latents into context groups
1378
+
1379
+ def context_scheduler(context_size, overlap, offset, total_frames, total_timesteps):
1380
+ # Calculate the number of context groups based on frame count and context size
1381
+ number_of_context_groups = (total_frames // (context_size - overlap))
1382
+ # Initialize a list to store context indexes for all timesteps
1383
+ all_context_indexes = []
1384
+ # Iterate over each timestep
1385
+ for timestep in range(total_timesteps):
1386
+ # Initialize a list to store groups of context indexes for this timestep
1387
+ timestep_context_groups = []
1388
+ # Iterate over each context group
1389
+ for group_index in range(number_of_context_groups):
1390
+ # Initialize a list to store context indexes for this group
1391
+ context_group_indexes = []
1392
+ # Iterate over each index in the context group
1393
+ local_context_size = context_size
1394
+ if timestep <= 1:
1395
+ local_context_size = int(context_size * initial_context_scale)
1396
+ for index in range(local_context_size):
+ # place the window at its group offset, then shift it by `offset` frames at every timestep
+ frame_index = (group_index * (local_context_size - overlap)) + (offset * timestep) + index
1399
+ # If frame index exceeds total frames, wrap around
1400
+ if frame_index >= total_frames:
1401
+ frame_index %= total_frames
1402
+ # Add the frame index to the group's list
1403
+ context_group_indexes.append(frame_index)
1404
+ # Add the group's indexes to the timestep's list
1405
+ timestep_context_groups.append(context_group_indexes)
1406
+ # Add the timestep's context groups to the overall list
1407
+ all_context_indexes.append(timestep_context_groups)
1408
+ return all_context_indexes
1409
+
1410
+ context_indexes = context_scheduler(context_size, overlap, step, num_frames, len(timesteps))
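+
+ # Illustrative example (not executed): with context_size=4, overlap=1, step=1,
+ # num_frames=6 and initial_context_scale=1.0 the scheduler yields two windows per
+ # timestep, e.g. [[0, 1, 2, 3], [3, 4, 5, 0]] at timestep 0 and [[1, 2, 3, 4], [4, 5, 0, 1]]
+ # at timestep 1; consecutive windows overlap, the whole pattern shifts by `step` frames
+ # every timestep, and indexes wrap around the end of the clip.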
1411
+
1412
  # Denoising loop
1413
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1414
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
1415
  for i, t in enumerate(timesteps):
1416
+ noise_pred_uncond_sum = torch.zeros_like(latents)
+ noise_pred_text_sum = torch.zeros_like(latents)
+ latent_counter = torch.zeros(num_frames, device=device, dtype=latents.dtype)
1419
+
1420
+ # denoise the current timestep for each context group separately
+ for context_group in range(len(context_indexes[i])):
+ # look up the frame indexes for this context group (overlap included)
+ current_context_indexes = context_indexes[i][context_group]
+
+ # select the relevant context frames from the latents
+ current_context_latents = latents[:, :, current_context_indexes, :, :]
1427
+
1428
+ # expand the latents if we are doing classifier free guidance
1429
+ latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
1430
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1431
+
1432
+ if self.controlnet is not None and i < int(control_end * num_inference_steps):
+ current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
+ current_context_conditioning_frames = torch.cat([current_context_conditioning_frames] * 2) if do_classifier_free_guidance else current_context_conditioning_frames
1439
+
1440
+
1441
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ control_model_input = current_context_latents
1444
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1445
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1446
+ else:
1447
+ control_model_input = latent_model_input
1448
+ controlnet_prompt_embeds = prompt_embeds
1449
+ controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(len(current_context_indexes), dim=0)
1450
+
1451
+ if isinstance(controlnet_keep[i], list):
1452
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1453
+ else:
1454
+ controlnet_cond_scale = controlnet_conditioning_scale
1455
+ if isinstance(controlnet_cond_scale, list):
1456
+ controlnet_cond_scale = controlnet_cond_scale[0]
1457
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1458
+
1459
+
1460
+ control_model_input = torch.transpose(control_model_input, 1, 2)
1461
+ control_model_input = control_model_input.reshape(
1462
+ (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
1463
+ )
1464
+
1465
+
1466
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1467
+ control_model_input,
1468
+ t,
1469
+ encoder_hidden_states=controlnet_prompt_embeds,
1470
+ controlnet_cond=current_context_conditioning_frames,
1471
+ conditioning_scale=cond_scale,
1472
+ guess_mode=guess_mode,
1473
+ return_dict=False,
1474
+ )
1475
+
1476
+ # predict the noise residual with the added ControlNet residuals
1478
+ noise_pred = self.unet(
1479
+ latent_model_input,
1480
+ t,
1481
+ encoder_hidden_states=prompt_embeds,
1482
+ cross_attention_kwargs=cross_attention_kwargs,
1483
+ added_cond_kwargs=added_cond_kwargs,
1484
+ down_block_additional_residuals=down_block_res_samples,
1485
+ mid_block_additional_residual=mid_block_res_sample,
1486
+ ).sample
1487
+
1488
+ else:
1489
+ # predict the noise residual without ControlNet
1492
+ noise_pred = self.unet(
1493
+ latent_model_input,
1494
+ t,
1495
+ encoder_hidden_states=prompt_embeds,
1496
+ cross_attention_kwargs=cross_attention_kwargs,
1497
+ added_cond_kwargs=added_cond_kwargs,
1498
+ ).sample
1499
+
1500
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = torch.chunk(noise_pred, 2, dim=0)
+
+ # accumulate the per-context predictions and count how many windows touched each frame
1516
+ noise_pred_uncond_sum[..., current_context_indexes, :, :] += noise_pred_uncond
1517
+ noise_pred_text_sum[..., current_context_indexes, :, :] += noise_pred_text
1518
+ latent_counter[current_context_indexes] += 1
1519
+
1520
+ # manually keep the scheduler's internal step index aligned with the current timestep
+ self.scheduler._step_index = i
1523
  # perform guidance
1524
+ if do_classifier_free_guidance:
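+ # frames covered by several context windows have accumulated multiple predictions,
+ # so divide by the per-frame counter to average them before applying guidance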
1525
+ latent_counter = latent_counter.reshape(1, 1, num_frames, 1, 1)
1526
+ noise_pred_uncond = noise_pred_uncond_sum / latent_counter
1527
+ noise_pred_text = noise_pred_text_sum / latent_counter
1528
+
1529
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1530
+
1531
  # compute the previous noisy sample x_t -> x_t-1
1532
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1533
+
1534
  # call the callback, if provided
1535
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1536
  progress_bar.update()
1537
  if callback is not None and i % callback_steps == 0:
1538
+ callback(i, t, latents)
1539
+
1540
  if output_type == "latent":
1541
+ return AnimateDiffPipelineOutput(frames=latents)
1542
+
1543
+ # save frames
1544
+ if output_path is not None:
+ output_batch_size = 2 # decode in small batches to avoid out-of-memory errors on long videos
+ # `output_path` is expected to contain a run of '#' characters as the frame-number placeholder
+ num_digits = output_path.count('#')
+ frame_format = output_path.replace('#' * num_digits, '{:0' + str(num_digits) + 'd}')
1548
+ for batch in range((num_frames + output_batch_size - 1) // output_batch_size):
1549
+ start_id = batch * output_batch_size
1550
+ end_id = min((batch + 1) * output_batch_size, num_frames)
1551
+ video_tensor = self.decode_latents(latents[:, :, start_id:end_id, :, :])
1552
+ video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
1553
+ for f_id, frame in enumerate(video[0]):
1554
+ frame.save(frame_format.format(start_id + f_id))
1555
+ return output_path
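+
+ # Illustrative usage: output_path="frames/out_####.png" writes frames/out_0000.png,
+ # frames/out_0001.png, ... one file per decoded frame, and the call returns the pattern.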
1556
+
1557
  # Post-processing
1558
  video_tensor = self.decode_latents(latents)
1559
 
 
1568
  if not return_dict:
1569
  return (video,)
1570
 
1571
+ return AnimateDiffPipelineOutput(frames=video)