Add remote code for Diffusers

#1 opened by hlky (HF staff)
README.md CHANGED
@@ -105,27 +105,11 @@ Go to [ComfyUI-Animemory-Loader](https://github.com/animEEEmpire/ComfyUI-Animemo
 
  3.Diffusers inference.
 
- - The pipeline has not been merged yet. Please use the following code to setup the environment.
- ```shell
- git clone https://github.com/huggingface/diffusers.git
- git clone https://github.com/animEEEmpire/diffusers_animemory
- cp diffusers_animemory/* diffusers -r
-
- # Method 1: Re-install diffusers. (Recommended)
- cd diffusers
- pip install .
-
- # Method 2: Call it locally. Change `YOUR_PATH` to the directory where you just cloned `diffusers` and `diffusers_animemory`.
- import sys
- sys.path.insert(0, 'YOUR_PATH/diffusers/src')
- ```
- - And then, you can use the following code to generate images.
-
  ```python
- from diffusers import AniMemoryPipeline
+ from diffusers import DiffusionPipeline
  import torch
 
- pipe = AniMemoryPipeline.from_pretrained("animEEEmpire/AniMemory-alpha", torch_dtype=torch.bfloat16)
+ pipe = DiffusionPipeline.from_pretrained("animEEEmpire/AniMemory-alpha", trust_remote_code=True, torch_dtype=torch.bfloat16)
  pipe.to("cuda")
 
  prompt = "一只凶恶的狼,猩红的眼神,在午夜咆哮,月光皎洁"
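For reviewers, a fuller usage sketch of the remote-code path this PR enables. It mirrors the example docstring added in `pipeline_animemory.py`; the sampling parameters (steps, resolution, guidance scale) are illustrative, not required values.

```python
import torch
from diffusers import DiffusionPipeline

# Loads AniMemoryPipeline and its custom components from this repo via remote code.
pipe = DiffusionPipeline.from_pretrained(
    "animEEEmpire/AniMemory-alpha",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

prompt = "一只凶恶的狼,猩红的眼神,在午夜咆哮,月光皎洁"
negative_prompt = "nsfw, worst quality, low quality, normal quality, low resolution, monochrome, blurry, wrong, Mutated hands and fingers, text, ugly faces, twisted, jpeg artifacts, watermark, low contrast, realistic"

image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=40,
    height=1024,
    width=1024,
    guidance_scale=6.0,
).images[0]
image.save("output.png")
```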
model_index.json CHANGED
@@ -1,5 +1,5 @@
  {
- "_class_name": "AniMemoryPipeLine",
+ "_class_name": ["pipeline_animemory", "AniMemoryPipeline"],
  "_diffusers_version": "0.32.0.dev0",
  "feature_extractor": [
  null,
@@ -11,15 +11,15 @@
  null
  ],
  "scheduler": [
- "diffusers",
+ "scheduling_euler_ancestral_discrete_x_pred",
  "EulerAncestralDiscreteXPredScheduler"
  ],
  "text_encoder": [
- "animemory",
+ "animemory_t5",
  "AniMemoryT5"
  ],
  "text_encoder_2": [
- "animemory",
+ "animemory_altclip",
  "AniMemoryAltCLip"
  ],
  "tokenizer": [
@@ -35,7 +35,7 @@
  "UNet2DConditionModel"
  ],
  "vae": [
- "animemory",
+ "modeling_movq",
  "MoVQ"
  ]
  }
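Each component entry now names a `[module, class]` pair from this repo rather than a class shipped with `diffusers`, which is what lets `trust_remote_code=True` resolve the custom pipeline, text encoders, scheduler, and VAE. A minimal sanity-check sketch (the expected class names are taken from `model_index.json` above; the commented values are expectations, not captured output):

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "animEEEmpire/AniMemory-alpha", trust_remote_code=True
)

# The custom classes defined in this repo should back each component.
print(type(pipe).__name__)                 # AniMemoryPipeline
print(type(pipe.text_encoder).__name__)    # AniMemoryT5
print(type(pipe.text_encoder_2).__name__)  # AniMemoryAltCLip
print(type(pipe.vae).__name__)             # MoVQ
print(type(pipe.scheduler).__name__)       # EulerAncestralDiscreteXPredScheduler
```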
pipeline_animemory.py ADDED
@@ -0,0 +1,1771 @@
1
+ # Copyright 2024 AniMemory Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ import numpy as np
17
+ import PIL.Image
18
+
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ from transformers import (
24
+ CLIPImageProcessor,
25
+ CLIPVisionModelWithProjection,
26
+ XLMRobertaTokenizerFast,
27
+ )
28
+
29
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
30
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
31
+ from diffusers.loaders import (
32
+ FromSingleFileMixin,
33
+ IPAdapterMixin,
34
+ StableDiffusionXLLoraLoaderMixin,
35
+ TextualInversionLoaderMixin,
36
+ )
37
+ from diffusers.models import ImageProjection, UNet2DConditionModel
38
+ from diffusers.models.attention_processor import (
39
+ AttnProcessor2_0,
40
+ FusedAttnProcessor2_0,
41
+ XFormersAttnProcessor,
42
+ )
43
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
44
+ from diffusers.utils import (
45
+ USE_PEFT_BACKEND,
46
+ deprecate,
47
+ is_torch_xla_available,
48
+ logging,
49
+ replace_example_docstring,
50
+ scale_lora_layers,
51
+ unscale_lora_layers,
52
+ )
53
+ from diffusers.utils.torch_utils import randn_tensor
54
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
55
+
56
+ from diffusers.utils import BaseOutput
57
+
58
+
59
+ if is_torch_xla_available():
60
+ import torch_xla.core.xla_model as xm
61
+
62
+ XLA_AVAILABLE = True
63
+ else:
64
+ XLA_AVAILABLE = False
65
+
66
+
67
+ @dataclass
68
+ class AniMemoryPipelineOutput(BaseOutput):
69
+ """
70
+ Output class for the AniMemory pipeline.
71
+
72
+ Args:
73
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
74
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
75
+ num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
76
+ """
77
+
78
+ images: Union[List[PIL.Image.Image], np.ndarray]
79
+
80
+
81
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
82
+
83
+ # TODO: update prompt case
84
+ EXAMPLE_DOC_STRING = """
85
+ Examples:
86
+ ```py
87
+ >>> import torch
88
+ >>> from diffusers import DiffusionPipeline
89
+
90
+ >>> pipe = DiffusionPipeline.from_pretrained("animEEEmpire/AniMemory-alpha", trust_remote_code=True, torch_dtype=torch.bfloat16)
91
+ >>> pipe = pipe.to("cuda")
92
+
93
+ >>> prompt = "一只凶恶的狼,猩红的眼神,在午夜咆哮,月光皎洁"
94
+ >>> negative_prompt = "nsfw, worst quality, low quality, normal quality, low resolution, monochrome, blurry, wrong, Mutated hands and fingers, text, ugly faces, twisted, jpeg artifacts, watermark, low contrast, realistic"
95
+ >>> image = pipe(
96
+ ... prompt=prompt,
97
+ ... negative_prompt=negative_prompt,
98
+ ... num_inference_steps=40,
99
+ ... height=1024,
100
+ ... width=1024,
101
+ ... guidance_scale=6.0,
102
+ ... ).images[0]
103
+ >>> image.save("output.png")
104
+ ```
105
+ """
106
+
107
+
108
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
109
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
110
+ r"""
111
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
112
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
113
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
114
+
115
+ Args:
116
+ noise_cfg (`torch.Tensor`):
117
+ The predicted noise tensor for the guided diffusion process.
118
+ noise_pred_text (`torch.Tensor`):
119
+ The predicted noise tensor for the text-guided diffusion process.
120
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
121
+ A rescale factor applied to the noise predictions.
122
+
123
+ Returns:
124
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
125
+ """
126
+ std_text = noise_pred_text.std(
127
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
128
+ )
129
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
130
+ # rescale the results from guidance (fixes overexposure)
131
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
132
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
133
+ noise_cfg = (
134
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
135
+ )
136
+ return noise_cfg
137
+
138
+
139
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
140
+ def retrieve_timesteps(
141
+ scheduler,
142
+ num_inference_steps: Optional[int] = None,
143
+ device: Optional[Union[str, torch.device]] = None,
144
+ timesteps: Optional[List[int]] = None,
145
+ sigmas: Optional[List[float]] = None,
146
+ **kwargs,
147
+ ):
148
+ r"""
149
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
150
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
151
+
152
+ Args:
153
+ scheduler (`SchedulerMixin`):
154
+ The scheduler to get timesteps from.
155
+ num_inference_steps (`int`):
156
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
157
+ must be `None`.
158
+ device (`str` or `torch.device`, *optional*):
159
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
160
+ timesteps (`List[int]`, *optional*):
161
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
162
+ `num_inference_steps` and `sigmas` must be `None`.
163
+ sigmas (`List[float]`, *optional*):
164
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
165
+ `num_inference_steps` and `timesteps` must be `None`.
166
+
167
+ Returns:
168
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
169
+ second element is the number of inference steps.
170
+ """
171
+ if timesteps is not None and sigmas is not None:
172
+ raise ValueError(
173
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
174
+ )
175
+ if timesteps is not None:
176
+ accepts_timesteps = "timesteps" in set(
177
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
178
+ )
179
+ if not accepts_timesteps:
180
+ raise ValueError(
181
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
182
+ f" timestep schedules. Please check whether you are using the correct scheduler."
183
+ )
184
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
185
+ timesteps = scheduler.timesteps
186
+ num_inference_steps = len(timesteps)
187
+ elif sigmas is not None:
188
+ accept_sigmas = "sigmas" in set(
189
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
190
+ )
191
+ if not accept_sigmas:
192
+ raise ValueError(
193
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
194
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
195
+ )
196
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
197
+ timesteps = scheduler.timesteps
198
+ num_inference_steps = len(timesteps)
199
+ else:
200
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
201
+ timesteps = scheduler.timesteps
202
+ return timesteps, num_inference_steps
203
+
204
+
205
+ def split_input_ids(
206
+ input_ids,
207
+ attention_mask,
208
+ start,
209
+ model_max_length,
210
+ bos_token_id,
211
+ eos_token_id,
212
+ pad_token_id,
213
+ ):
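+ # Split a long tokenized prompt into windows of `model_max_length`, re-adding BOS/EOS at
+ # each window boundary and keeping the attention mask aligned with the ids.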
214
+ iids_list = []
215
+ mask_list = []
216
+ if start > 0:
217
+ cur_input_ids = input_ids[start - 1 :]
218
+ cur_input_ids[0] = bos_token_id
219
+ if attention_mask is not None:
220
+ cur_attention_mask = attention_mask[start - 1 :]
221
+ cur_attention_mask[0] = 1
222
+ else:
223
+ cur_input_ids = input_ids
224
+ if attention_mask is not None:
225
+ cur_attention_mask = attention_mask
226
+ n = len(cur_input_ids)
227
+
228
+ for i in range(1, n - model_max_length + 2, model_max_length - 2):
229
+ ids_chunk = (
230
+ cur_input_ids[0].unsqueeze(0),
231
+ cur_input_ids[i : i + model_max_length - 2],
232
+ cur_input_ids[-1].unsqueeze(0),
233
+ )
234
+ ids_chunk = torch.cat(ids_chunk)
235
+ if attention_mask is not None:
236
+ mask_chunk = (
237
+ cur_attention_mask[0].unsqueeze(0),
238
+ cur_attention_mask[i : i + model_max_length - 2],
239
+ cur_attention_mask[-1].unsqueeze(0),
240
+ )
241
+ mask_chunk = torch.cat(mask_chunk)
242
+
243
+ if ids_chunk[-2] != eos_token_id and ids_chunk[-2] != pad_token_id:
244
+ ids_chunk[-1] = eos_token_id
245
+ if attention_mask is not None:
246
+ mask_chunk[-1] = 1
247
+ if ids_chunk[1] == pad_token_id:
248
+ ids_chunk[1] = eos_token_id
249
+ if attention_mask is not None:
250
+ mask_chunk[1] = 1
251
+
252
+ iids_list.append(ids_chunk)
253
+ if attention_mask is not None:
254
+ mask_list.append(mask_chunk)
255
+
256
+ return iids_list, mask_list if len(mask_list) > 0 else None
257
+
258
+
259
+ # Modified from [library.train_util.get_input_ids](https://github.com/kohya-ss/sd-scripts/blob/e5ac09574928ec02fba5fe78267764d26bb7faa6/library/train_util.py#L795)
260
+ def get_input_ids(
261
+ caption,
262
+ tokenizer,
263
+ tokenizer_max_length,
264
+ dense_caption_split_method,
265
+ chunk,
266
+ punctuation_ids,
267
+ ):
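+ # Tokenize `caption`; when `chunk` is True, split the ids into at most three windows of
+ # `tokenizer.model_max_length`, either by fixed length or at punctuation token ids.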
268
+ prompt_tokens = tokenizer(
269
+ caption,
270
+ max_length=tokenizer_max_length,
271
+ padding="max_length",
272
+ truncation=True,
273
+ return_tensors="pt",
274
+ )
275
+ input_ids = prompt_tokens["input_ids"].squeeze(0)
276
+ attention_mask = prompt_tokens["attention_mask"].squeeze(0)
277
+
278
+ if not chunk:
279
+ return input_ids[None, ...], attention_mask[None, ...]
280
+
281
+ iids_list = []
282
+ mask_list = []
283
+
284
+ if dense_caption_split_method == "length_split":
285
+ iids_list, mask_list = split_input_ids(
286
+ input_ids,
287
+ attention_mask,
288
+ 0,
289
+ tokenizer.model_max_length,
290
+ tokenizer.bos_token_id,
291
+ tokenizer.eos_token_id,
292
+ tokenizer.pad_token_id,
293
+ )
294
+ elif dense_caption_split_method == "punctuation_split":
295
+ can_split_tensor = torch.zeros_like(input_ids)
296
+ for punctuation_id in punctuation_ids:
297
+ can_split_tensor = torch.logical_or(
298
+ can_split_tensor, input_ids == punctuation_id
299
+ )
300
+ can_split_index = (
301
+ [0]
302
+ + [i[0] for i in torch.nonzero(can_split_tensor).tolist()]
303
+ + [len(input_ids) - 1]
304
+ )
305
+ start = 1
306
+ end = 1
307
+
308
+ new_can_split_index = []
309
+ for i in range(len(can_split_index) - 1):
310
+ pre = can_split_index[i]
311
+ new_can_split_index.append(pre)
312
+ nxt = can_split_index[i + 1]
313
+ cur = pre + tokenizer.model_max_length - 2
314
+ while cur < nxt:
315
+ new_can_split_index.append(cur)
316
+ cur = cur + tokenizer.model_max_length - 2
317
+ new_can_split_index.append(can_split_index[-1])
318
+ can_split_index = new_can_split_index
319
+
320
+ for i in can_split_index:
321
+ if i - start + 1 > tokenizer.model_max_length - 2:
322
+ if end == start:
323
+ end = start + (tokenizer.model_max_length - 2)
324
+ ids_chunk = torch.tensor(
325
+ [tokenizer.pad_token_id] * tokenizer.model_max_length,
326
+ dtype=torch.int64,
327
+ )
328
+ ids_chunk[0] = tokenizer.bos_token_id
329
+ ids_chunk[1 : 1 + end - start] = input_ids[start:end]
330
+ ids_chunk[1 + end - start] = input_ids[-1]
331
+ mask_chunk = torch.zeros(tokenizer.model_max_length).to(torch.int64)
332
+ mask_chunk[0] = 1
333
+ mask_chunk[1 : 1 + end - start] = attention_mask[start:end]
334
+ mask_chunk[1 + end - start] = attention_mask[-1]
335
+ if ids_chunk[1] == tokenizer.pad_token_id:
336
+ ids_chunk[1] = tokenizer.eos_token_id
337
+ mask_chunk[1] = 1
338
+ if tokenizer.eos_token_id not in ids_chunk:
339
+ ids_chunk[1 + end - start] = tokenizer.eos_token_id
340
+ mask_chunk[1 + end - start] = 1
341
+ iids_list.append(ids_chunk)
342
+ mask_list.append(mask_chunk)
343
+ if len(iids_list) == 3:
344
+ break
345
+ start = end
346
+ end = i + 1
347
+
348
+ if len(iids_list) == 0:
349
+ iids_list, mask_list = split_input_ids(
350
+ input_ids,
351
+ attention_mask,
352
+ 0,
353
+ tokenizer.model_max_length,
354
+ tokenizer.bos_token_id,
355
+ tokenizer.eos_token_id,
356
+ tokenizer.pad_token_id,
357
+ )
358
+ elif len(iids_list) == 1:
359
+ iids_list1, mask_list1 = split_input_ids(
360
+ input_ids,
361
+ attention_mask,
362
+ start,
363
+ tokenizer.model_max_length,
364
+ tokenizer.bos_token_id,
365
+ tokenizer.eos_token_id,
366
+ tokenizer.pad_token_id,
367
+ )
368
+ iids_list = (iids_list + iids_list1)[:3]
369
+ mask_list = (mask_list + mask_list1)[:3]
370
+ elif len(iids_list) == 2:
371
+ iids_list1, mask_list1 = split_input_ids(
372
+ input_ids,
373
+ attention_mask,
374
+ start,
375
+ tokenizer.model_max_length,
376
+ tokenizer.bos_token_id,
377
+ tokenizer.eos_token_id,
378
+ tokenizer.pad_token_id,
379
+ )
380
+ iids_list = (iids_list + iids_list1)[:3]
381
+ mask_list = (mask_list + mask_list1)[:3]
382
+ else:
383
+ raise NotImplementedError
384
+
385
+ input_ids = torch.stack(iids_list)
386
+ attention_mask = torch.stack(mask_list)
387
+
388
+ return input_ids, attention_mask
389
+
390
+
391
+ class AniMemoryPipeline(
392
+ DiffusionPipeline,
393
+ StableDiffusionMixin,
394
+ FromSingleFileMixin,
395
+ StableDiffusionXLLoraLoaderMixin,
396
+ TextualInversionLoaderMixin,
397
+ IPAdapterMixin,
398
+ ):
399
+ # TODO: review
400
+ r"""
401
+ Pipeline for text-to-image generation using AniMemory.
402
+
403
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
404
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
405
+
406
+ The pipeline also inherits the following loading methods:
407
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
408
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
409
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
410
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
411
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
412
+
413
+ Args:
414
+ vae ([`MoVQ`]):
415
+ Variational Auto-Encoder (VAE) Model. AniMemory uses
416
+ [MoVQ](https://github.com/ai-forever/Kandinsky-3/blob/main/kandinsky3/movq.py)
417
+ text_encoder ([`AniMemoryT5`]):
418
+ Frozen text-encoder. AniMemory builds based on
419
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel).
420
+ text_encoder_2 ([`AniMemoryAltCLip`]):
421
+ Second frozen text-encoder. AniMemory builds based on
422
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
423
+ tokenizer (`XLMRobertaTokenizerFast`):
424
+ Tokenizer of class
425
+ [XLMRobertaTokenizerFast](https://huggingface.co/docs/transformers/v4.46.3/en/model_doc/xlm-roberta#transformers.XLMRobertaTokenizerFast).
426
+ tokenizer_2 (`XLMRobertaTokenizerFast`):
427
+ Second Tokenizer of class
428
+ [XLMRobertaTokenizerFast](https://huggingface.co/docs/transformers/v4.46.3/en/model_doc/xlm-roberta#transformers.XLMRobertaTokenizerFast).
429
+ unet ([`UNet2DConditionModel`]):
430
+ Conditional U-Net architecture to denoise the encoded image latents.
431
+ scheduler ([`EulerAncestralDiscreteXPredScheduler`]):
432
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
433
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
434
+ Whether the negative prompt embeddings shall be forced to always be set to 0.
435
+ """
436
+
437
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
438
+ _optional_components = [
439
+ "tokenizer",
440
+ "tokenizer_2",
441
+ "text_encoder",
442
+ "text_encoder_2",
443
+ "image_encoder",
444
+ "feature_extractor",
445
+ ]
446
+ _callback_tensor_inputs = [
447
+ "latents",
448
+ "prompt_embeds",
449
+ "negative_prompt_embeds",
450
+ "add_text_embeds",
451
+ "add_time_ids",
452
+ "negative_pooled_prompt_embeds",
453
+ "negative_add_time_ids",
454
+ ]
455
+
456
+ def __init__(
457
+ self,
458
+ vae: "MoVQ", # type: ignore
459
+ text_encoder: "AniMemoryT5", # type: ignore
460
+ text_encoder_2: "AniMemoryAltCLip", # type: ignore
461
+ tokenizer: XLMRobertaTokenizerFast,
462
+ tokenizer_2: XLMRobertaTokenizerFast,
463
+ unet: UNet2DConditionModel,
464
+ scheduler: "EulerAncestralDiscreteXPredScheduler", # type: ignore
465
+ image_encoder: CLIPVisionModelWithProjection = None,
466
+ feature_extractor: CLIPImageProcessor = None,
467
+ force_zeros_for_empty_prompt: bool = True,
468
+ ):
469
+ super().__init__()
470
+
471
+ self.register_modules(
472
+ vae=vae,
473
+ text_encoder=text_encoder,
474
+ text_encoder_2=text_encoder_2,
475
+ tokenizer=tokenizer,
476
+ tokenizer_2=tokenizer_2,
477
+ unet=unet,
478
+ scheduler=scheduler,
479
+ image_encoder=image_encoder,
480
+ feature_extractor=feature_extractor,
481
+ )
482
+ self.register_to_config(
483
+ force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
484
+ )
485
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
486
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
487
+
488
+ self.default_sample_size = self.unet.config.sample_size
489
+
490
+ self.unet.time_proj.downscale_freq_shift = 1
491
+
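+ # Override the scheduler config for x-prediction sampling: no sample clipping, linspace
+ # timestep spacing, "sample" prediction, and betas rescaled to zero terminal SNR.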
492
+ self.scheduler.config.clip_sample = False
493
+ self.scheduler.config.timestep_spacing = "linspace"
494
+ self.scheduler.config.prediction_type = "sample"
495
+ self.scheduler.rescale_betas_zero_snr()
496
+
497
+ def encode_prompt(
498
+ self,
499
+ prompt: str,
500
+ prompt_2: Optional[str] = None,
501
+ device: Optional[torch.device] = None,
502
+ num_images_per_prompt: int = 1,
503
+ do_classifier_free_guidance: bool = True,
504
+ negative_prompt: Optional[str] = None,
505
+ negative_prompt_2: Optional[str] = None,
506
+ prompt_embeds: Optional[torch.Tensor] = None,
507
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
508
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
509
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
510
+ lora_scale: Optional[float] = None,
511
+ clip_skip: Optional[int] = None,
512
+ ):
513
+ r"""
514
+ Encodes the prompt into text encoder hidden states.
515
+
516
+ Args:
517
+ prompt (`str` or `List[str]`, *optional*):
518
+ prompt to be encoded
519
+ prompt_2 (`str` or `List[str]`, *optional*):
520
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
521
+ used in both text-encoders
522
+ device: (`torch.device`):
523
+ torch device
524
+ num_images_per_prompt (`int`):
525
+ number of images that should be generated per prompt
526
+ do_classifier_free_guidance (`bool`):
527
+ whether to use classifier free guidance or not
528
+ negative_prompt (`str` or `List[str]`, *optional*):
529
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
530
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
531
+ less than `1`).
532
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
533
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
534
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
535
+ prompt_embeds (`torch.Tensor`, *optional*):
536
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
537
+ provided, text embeddings will be generated from `prompt` input argument.
538
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
539
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
540
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
541
+ argument.
542
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
543
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
544
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
545
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
546
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
547
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
548
+ input argument.
549
+ lora_scale (`float`, *optional*):
550
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
551
+ clip_skip (`int`, *optional*):
552
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
553
+ the output of the pre-final layer will be used for computing the prompt embeddings.
554
+ """
555
+ if device is None:
556
+ device = self._execution_device
557
+
558
+ # set lora scale so that monkey patched LoRA
559
+ # function of text encoder can correctly access it
560
+ if lora_scale is not None and isinstance(
561
+ self, StableDiffusionXLLoraLoaderMixin
562
+ ):
563
+ self._lora_scale = lora_scale
564
+
565
+ # dynamically adjust the LoRA scale
566
+ if self.text_encoder is not None:
567
+ if not USE_PEFT_BACKEND:
568
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
569
+ else:
570
+ scale_lora_layers(self.text_encoder, lora_scale)
571
+
572
+ if self.text_encoder_2 is not None:
573
+ if not USE_PEFT_BACKEND:
574
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
575
+ else:
576
+ scale_lora_layers(self.text_encoder_2, lora_scale)
577
+
578
+ prompt = [prompt] if isinstance(prompt, str) else prompt
579
+
580
+ if prompt is not None:
581
+ batch_size = len(prompt)
582
+ else:
583
+ batch_size = prompt_embeds.shape[0]
584
+
585
+ # Define tokenizers and text encoders
586
+ tokenizers = (
587
+ [self.tokenizer, self.tokenizer_2]
588
+ if self.tokenizer is not None
589
+ else [self.tokenizer_2]
590
+ )
591
+ text_encoders = (
592
+ [self.text_encoder, self.text_encoder_2]
593
+ if self.text_encoder is not None
594
+ else [self.text_encoder_2]
595
+ )
596
+
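+ # Token ids treated as punctuation split points for each tokenizer, and the maximum
+ # combined prompt length in tokens used when chunking.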
597
+ punctuation_ids = [
598
+ [5, 4, 74, 32, 38, 4730, 30, 4, 74, 32, 38, 4730],
599
+ [5, 4, 74, 32, 38, 4730, 30, 4, 74, 32, 38, 4730],
600
+ ]
601
+ max_token_length = 227
602
+
603
+ if prompt_embeds is None:
604
+ prompt_2 = prompt_2 or prompt
605
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
606
+
607
+ # textual inversion: process multi-vector tokens if necessary
608
+ prompt_embeds_list = []
609
+ prompts = [prompt, prompt_2]
610
+ text_encoder_idx = 0
611
+ for prompt, tokenizer, text_encoder in zip(
612
+ prompts, tokenizers, text_encoders
613
+ ):
614
+ text_input_ids, attention_mask = get_input_ids(
615
+ prompt,
616
+ tokenizers[text_encoder_idx],
617
+ max_token_length,
618
+ "punctuation_split",
619
+ False if text_encoder_idx == 0 else True,
620
+ punctuation_ids[text_encoder_idx],
621
+ )
622
+
623
+ tk_len = text_input_ids.shape[-1]
624
+ text_input_ids = text_input_ids.reshape((-1, tk_len))
625
+ attention_mask = attention_mask.reshape((-1, tk_len))
626
+
627
+ prompt_embeds, pooled_output = text_encoder(
628
+ text_input_ids.to(device), attention_mask.to(device)
629
+ )
630
+
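+ # For the second (chunked) encoder, merge the per-chunk hidden states back into one
+ # sequence, dropping the duplicated BOS/EOS between chunks, and keep one pooled output
+ # per prompt.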
631
+ if text_encoder_idx == 1:
632
+ tmp_ids = text_input_ids.reshape(-1, 3, text_input_ids.shape[-1])
633
+ _, n2, tk_len2 = tmp_ids.size()
634
+ prompt_embeds = prompt_embeds.reshape(
635
+ (-1, n2 * tk_len2, prompt_embeds.shape[-1])
636
+ )
637
+ if n2 > 1:
638
+ states_list = [prompt_embeds[:, 0].unsqueeze(1)]
639
+ for i in range(
640
+ 1,
641
+ max_token_length,
642
+ tokenizers[text_encoder_idx].model_max_length,
643
+ ):
644
+ states_list.append(
645
+ prompt_embeds[
646
+ :,
647
+ i : i
648
+ + tokenizers[text_encoder_idx].model_max_length
649
+ - 2,
650
+ ]
651
+ )
652
+ states_list.append(prompt_embeds[:, -1].unsqueeze(1))
653
+ prompt_embeds = torch.cat(states_list, dim=1)
654
+
655
+ pooled_prompt_embeds = pooled_output[::n2]
656
+
657
+ prompt_embeds_list.append(prompt_embeds)
658
+ text_encoder_idx += 1
659
+
660
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
661
+
662
+ # get unconditional embeddings for classifier free guidance
663
+ zero_out_negative_prompt = (
664
+ negative_prompt is None and self.config.force_zeros_for_empty_prompt
665
+ )
666
+ if (
667
+ do_classifier_free_guidance
668
+ and negative_prompt_embeds is None
669
+ and zero_out_negative_prompt
670
+ ):
671
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
672
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
673
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
674
+ negative_prompt = negative_prompt or ""
675
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
676
+
677
+ negative_prompt = (
678
+ batch_size * [negative_prompt]
679
+ if isinstance(negative_prompt, str)
680
+ else negative_prompt
681
+ )
682
+ negative_prompt_2 = (
683
+ batch_size * [negative_prompt_2]
684
+ if isinstance(negative_prompt_2, str)
685
+ else negative_prompt_2
686
+ )
687
+
688
+ uncond_tokens: List[str]
689
+ if prompt is not None and type(prompt) is not type(negative_prompt):
690
+ raise TypeError(
691
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
692
+ f" {type(prompt)}."
693
+ )
694
+ elif batch_size != len(negative_prompt):
695
+ raise ValueError(
696
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
697
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
698
+ " the batch size of `prompt`."
699
+ )
700
+ else:
701
+ uncond_tokens = [negative_prompt, negative_prompt_2]
702
+
703
+ negative_prompt_embeds_list = []
704
+ text_encoder_idx = 0
705
+ for negative_prompt, tokenizer, text_encoder in zip(
706
+ uncond_tokens, tokenizers, text_encoders
707
+ ):
708
+ if isinstance(self, TextualInversionLoaderMixin):
709
+ negative_prompt = self.maybe_convert_prompt(
710
+ negative_prompt, tokenizer
711
+ )
712
+
713
+ negative_text_input_ids, negative_attention_mask = get_input_ids(
714
+ negative_prompt,
715
+ tokenizers[text_encoder_idx],
716
+ max_token_length,
717
+ "punctuation_split",
718
+ False if text_encoder_idx == 0 else True,
719
+ punctuation_ids[text_encoder_idx],
720
+ )
721
+
722
+ tk_len = negative_text_input_ids.shape[-1]
723
+ negative_text_input_ids = negative_text_input_ids.reshape((-1, tk_len))
724
+ negative_attention_mask = negative_attention_mask.reshape((-1, tk_len))
725
+
726
+ negative_prompt_embeds, negative_pooled_ouput = text_encoder(
727
+ negative_text_input_ids.to(device),
728
+ negative_attention_mask.to(device),
729
+ )
730
+
731
+ if text_encoder_idx == 1:
732
+ negative_tmp_ids = negative_text_input_ids.reshape(
733
+ -1, 3, negative_text_input_ids.shape[-1]
734
+ )
735
+ _, n2, tk_len2 = negative_tmp_ids.size()
736
+ negative_prompt_embeds = negative_prompt_embeds.reshape(
737
+ (-1, n2 * tk_len2, negative_prompt_embeds.shape[-1])
738
+ )
739
+ if n2 > 1:
740
+ states_list = [negative_prompt_embeds[:, 0].unsqueeze(1)]
741
+ for i in range(
742
+ 1,
743
+ max_token_length,
744
+ tokenizers[text_encoder_idx].model_max_length,
745
+ ):
746
+ states_list.append(
747
+ negative_prompt_embeds[
748
+ :,
749
+ i : i
750
+ + tokenizers[text_encoder_idx].model_max_length
751
+ - 2,
752
+ ]
753
+ )
754
+ states_list.append(negative_prompt_embeds[:, -1].unsqueeze(1))
755
+ negative_prompt_embeds = torch.cat(states_list, dim=1)
756
+ negative_pooled_prompt_embeds = negative_pooled_output[::n2]
757
+
758
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
759
+ text_encoder_idx += 1
760
+
761
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
762
+
763
+ if self.text_encoder_2 is not None:
764
+ prompt_embeds = prompt_embeds.to(
765
+ dtype=self.text_encoder_2.dtype, device=device
766
+ )
767
+ else:
768
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
769
+
770
+ bs_embed, seq_len, _ = prompt_embeds.shape
771
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
772
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
773
+ prompt_embeds = prompt_embeds.view(
774
+ bs_embed * num_images_per_prompt, seq_len, -1
775
+ )
776
+
777
+ if do_classifier_free_guidance:
778
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
779
+ seq_len = negative_prompt_embeds.shape[1]
780
+
781
+ if self.text_encoder_2 is not None:
782
+ negative_prompt_embeds = negative_prompt_embeds.to(
783
+ dtype=self.text_encoder_2.dtype, device=device
784
+ )
785
+ else:
786
+ negative_prompt_embeds = negative_prompt_embeds.to(
787
+ dtype=self.unet.dtype, device=device
788
+ )
789
+
790
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
791
+ 1, num_images_per_prompt, 1
792
+ )
793
+ negative_prompt_embeds = negative_prompt_embeds.view(
794
+ batch_size * num_images_per_prompt, seq_len, -1
795
+ )
796
+
797
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(
798
+ 1, num_images_per_prompt
799
+ ).view(bs_embed * num_images_per_prompt, -1)
800
+ if do_classifier_free_guidance:
801
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
802
+ 1, num_images_per_prompt
803
+ ).view(bs_embed * num_images_per_prompt, -1)
804
+
805
+ if self.text_encoder is not None:
806
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
807
+ # Retrieve the original scale by scaling back the LoRA layers
808
+ unscale_lora_layers(self.text_encoder, lora_scale)
809
+
810
+ if self.text_encoder_2 is not None:
811
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
812
+ # Retrieve the original scale by scaling back the LoRA layers
813
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
814
815
+ return (
816
+ prompt_embeds,
817
+ negative_prompt_embeds,
818
+ pooled_prompt_embeds,
819
+ negative_pooled_prompt_embeds,
820
+ )
821
+
822
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
823
+ def encode_image(
824
+ self, image, device, num_images_per_prompt, output_hidden_states=None
825
+ ):
826
+ dtype = next(self.image_encoder.parameters()).dtype
827
+
828
+ if not isinstance(image, torch.Tensor):
829
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
830
+
831
+ image = image.to(device=device, dtype=dtype)
832
+ if output_hidden_states:
833
+ image_enc_hidden_states = self.image_encoder(
834
+ image, output_hidden_states=True
835
+ ).hidden_states[-2]
836
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
837
+ num_images_per_prompt, dim=0
838
+ )
839
+ uncond_image_enc_hidden_states = self.image_encoder(
840
+ torch.zeros_like(image), output_hidden_states=True
841
+ ).hidden_states[-2]
842
+ uncond_image_enc_hidden_states = (
843
+ uncond_image_enc_hidden_states.repeat_interleave(
844
+ num_images_per_prompt, dim=0
845
+ )
846
+ )
847
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
848
+ else:
849
+ image_embeds = self.image_encoder(image).image_embeds
850
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
851
+ uncond_image_embeds = torch.zeros_like(image_embeds)
852
+
853
+ return image_embeds, uncond_image_embeds
854
+
855
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
856
+ def prepare_ip_adapter_image_embeds(
857
+ self,
858
+ ip_adapter_image,
859
+ ip_adapter_image_embeds,
860
+ device,
861
+ num_images_per_prompt,
862
+ do_classifier_free_guidance,
863
+ ):
864
+ image_embeds = []
865
+ if do_classifier_free_guidance:
866
+ negative_image_embeds = []
867
+ if ip_adapter_image_embeds is None:
868
+ if not isinstance(ip_adapter_image, list):
869
+ ip_adapter_image = [ip_adapter_image]
870
+
871
+ if len(ip_adapter_image) != len(
872
+ self.unet.encoder_hid_proj.image_projection_layers
873
+ ):
874
+ raise ValueError(
875
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
876
+ )
877
+
878
+ for single_ip_adapter_image, image_proj_layer in zip(
879
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
880
+ ):
881
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
882
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
883
+ single_ip_adapter_image, device, 1, output_hidden_state
884
+ )
885
+
886
+ image_embeds.append(single_image_embeds[None, :])
887
+ if do_classifier_free_guidance:
888
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
889
+ else:
890
+ for single_image_embeds in ip_adapter_image_embeds:
891
+ if do_classifier_free_guidance:
892
+ (
893
+ single_negative_image_embeds,
894
+ single_image_embeds,
895
+ ) = single_image_embeds.chunk(2)
896
+ negative_image_embeds.append(single_negative_image_embeds)
897
+ image_embeds.append(single_image_embeds)
898
+
899
+ ip_adapter_image_embeds = []
900
+ for i, single_image_embeds in enumerate(image_embeds):
901
+ single_image_embeds = torch.cat(
902
+ [single_image_embeds] * num_images_per_prompt, dim=0
903
+ )
904
+ if do_classifier_free_guidance:
905
+ single_negative_image_embeds = torch.cat(
906
+ [negative_image_embeds[i]] * num_images_per_prompt, dim=0
907
+ )
908
+ single_image_embeds = torch.cat(
909
+ [single_negative_image_embeds, single_image_embeds], dim=0
910
+ )
911
+
912
+ single_image_embeds = single_image_embeds.to(device=device)
913
+ ip_adapter_image_embeds.append(single_image_embeds)
914
+
915
+ return ip_adapter_image_embeds
916
+
917
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
918
+ def prepare_extra_step_kwargs(self, generator, eta):
919
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
920
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
921
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
922
+ # and should be between [0, 1]
923
+
924
+ accepts_eta = "eta" in set(
925
+ inspect.signature(self.scheduler.step).parameters.keys()
926
+ )
927
+ extra_step_kwargs = {}
928
+ if accepts_eta:
929
+ extra_step_kwargs["eta"] = eta
930
+
931
+ # check if the scheduler accepts generator
932
+ accepts_generator = "generator" in set(
933
+ inspect.signature(self.scheduler.step).parameters.keys()
934
+ )
935
+ if accepts_generator:
936
+ extra_step_kwargs["generator"] = generator
937
+ return extra_step_kwargs
938
+
939
+ def check_inputs(
940
+ self,
941
+ prompt,
942
+ prompt_2,
943
+ height,
944
+ width,
945
+ callback_steps,
946
+ negative_prompt=None,
947
+ negative_prompt_2=None,
948
+ prompt_embeds=None,
949
+ negative_prompt_embeds=None,
950
+ pooled_prompt_embeds=None,
951
+ negative_pooled_prompt_embeds=None,
952
+ ip_adapter_image=None,
953
+ ip_adapter_image_embeds=None,
954
+ callback_on_step_end_tensor_inputs=None,
955
+ ):
956
+ if height % 8 != 0 or width % 8 != 0:
957
+ raise ValueError(
958
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
959
+ )
960
+
961
+ if callback_steps is not None and (
962
+ not isinstance(callback_steps, int) or callback_steps <= 0
963
+ ):
964
+ raise ValueError(
965
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
966
+ f" {type(callback_steps)}."
967
+ )
968
+
969
+ if callback_on_step_end_tensor_inputs is not None and not all(
970
+ k in self._callback_tensor_inputs
971
+ for k in callback_on_step_end_tensor_inputs
972
+ ):
973
+ raise ValueError(
974
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
975
+ )
976
+
977
+ if prompt is not None and prompt_embeds is not None:
978
+ raise ValueError(
979
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
980
+ " only forward one of the two."
981
+ )
982
+ elif prompt_2 is not None and prompt_embeds is not None:
983
+ raise ValueError(
984
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
985
+ " only forward one of the two."
986
+ )
987
+ elif prompt is None and prompt_embeds is None:
988
+ raise ValueError(
989
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
990
+ )
991
+ elif prompt is not None and (
992
+ not isinstance(prompt, str) and not isinstance(prompt, list)
993
+ ):
994
+ raise ValueError(
995
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
996
+ )
997
+ elif prompt_2 is not None and (
998
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
999
+ ):
1000
+ raise ValueError(
1001
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
1002
+ )
1003
+
1004
+ if negative_prompt is not None and negative_prompt_embeds is not None:
1005
+ raise ValueError(
1006
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
1007
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
1008
+ )
1009
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
1010
+ raise ValueError(
1011
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
1012
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
1013
+ )
1014
+
1015
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
1016
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
1017
+ raise ValueError(
1018
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
1019
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
1020
+ f" {negative_prompt_embeds.shape}."
1021
+ )
1022
+
1023
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
1024
+ raise ValueError(
1025
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
1026
+ )
1027
+
1028
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
1029
+ raise ValueError(
1030
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
1031
+ )
1032
+
1033
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
1034
+ raise ValueError(
1035
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
1036
+ )
1037
+
1038
+ if ip_adapter_image_embeds is not None:
1039
+ if not isinstance(ip_adapter_image_embeds, list):
1040
+ raise ValueError(
1041
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
1042
+ )
1043
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
1044
+ raise ValueError(
1045
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
1046
+ )
1047
+
1048
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
1049
+ def prepare_latents(
1050
+ self,
1051
+ batch_size,
1052
+ num_channels_latents,
1053
+ height,
1054
+ width,
1055
+ dtype,
1056
+ device,
1057
+ generator,
1058
+ latents=None,
1059
+ ):
1060
+ shape = (
1061
+ batch_size,
1062
+ num_channels_latents,
1063
+ int(height) // self.vae_scale_factor,
1064
+ int(width) // self.vae_scale_factor,
1065
+ )
1066
+ if isinstance(generator, list) and len(generator) != batch_size:
1067
+ raise ValueError(
1068
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
1069
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
1070
+ )
1071
+
1072
+ if latents is None:
1073
+ latents = randn_tensor(
1074
+ shape, generator=generator, device=device, dtype=dtype
1075
+ )
1076
+ else:
1077
+ latents = latents.to(device)
1078
+
1079
+ # scale the initial noise by the standard deviation required by the scheduler
1080
+ latents = latents * self.scheduler.init_noise_sigma
1081
+ return latents
1082
+
1083
+ def _get_add_time_ids(
1084
+ self,
1085
+ original_size,
1086
+ crops_coords_top_left,
1087
+ target_size,
1088
+ dtype,
1089
+ text_encoder_projection_dim=None,
1090
+ ):
1091
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1092
+
1093
+ passed_add_embed_dim = (
1094
+ self.unet.config.addition_time_embed_dim * len(add_time_ids)
1095
+ + text_encoder_projection_dim
1096
+ )
1097
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
1098
+
1099
+ if expected_add_embed_dim != passed_add_embed_dim:
1100
+ raise ValueError(
1101
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
1102
+ )
1103
+
1104
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
1105
+ return add_time_ids
1106
+
1107
+ @property
1108
+ def device(self) -> torch.device:
1109
+ r"""
1110
+ Returns:
1111
+ `torch.device`: The torch device on which the pipeline is located.
1112
+ """
1113
+ module_names, _ = self._get_signature_keys(self)
1114
+ modules = [getattr(self, n, None) for n in module_names]
1115
+ modules = [m for m in modules if isinstance(m, torch.nn.Module)]
1116
+
1117
+ for module in modules:
1118
+ return module.device
1119
+
1120
+ return torch.device("cpu")
1121
+
1122
+ @property
1123
+ def _execution_device(self):
1124
+ """
1125
+ Returns the device on which the pipeline's models will be executed. After calling
1126
+ [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
1127
+ Accelerate's module hooks.
1128
+ """
1129
+ for name, model in self.components.items():
1130
+ if (
1131
+ not isinstance(model, torch.nn.Module)
1132
+ or name in self._exclude_from_cpu_offload
1133
+ ):
1134
+ continue
1135
+
1136
+ if not hasattr(model, "_hf_hook"):
1137
+ return self.device
1138
+ for module in model.modules():
1139
+ if (
1140
+ hasattr(module, "_hf_hook")
1141
+ and hasattr(module._hf_hook, "execution_device")
1142
+ and module._hf_hook.execution_device is not None
1143
+ ):
1144
+ return torch.device(module._hf_hook.execution_device)
1145
+ return self.device
1146
+
1147
+ def upcast_vae(self):
1148
+ dtype = self.vae.dtype
1149
+ self.vae.to(dtype=torch.float32)
1150
+ use_torch_2_0_or_xformers = isinstance(
1151
+ self.vae.decoder.mid_block.attentions[0].processor,
1152
+ (
1153
+ AttnProcessor2_0,
1154
+ XFormersAttnProcessor,
1155
+ FusedAttnProcessor2_0,
1156
+ ),
1157
+ )
1158
+ # if xformers or torch_2_0 is used attention block does not need
1159
+ # to be in float32 which can save lots of memory
1160
+ if use_torch_2_0_or_xformers:
1161
+ self.vae.post_quant_conv.to(dtype)
1162
+ self.vae.decoder.conv_in.to(dtype)
1163
+ self.vae.decoder.mid_block.to(dtype)
1164
+
1165
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
1166
+ def get_guidance_scale_embedding(
1167
+ self,
1168
+ w: torch.Tensor,
1169
+ embedding_dim: int = 512,
1170
+ dtype: torch.dtype = torch.float32,
1171
+ ) -> torch.Tensor:
1172
+ """
1173
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1174
+
1175
+ Args:
1176
+ w (`torch.Tensor`):
1177
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
1178
+ embedding_dim (`int`, *optional*, defaults to 512):
1179
+ Dimension of the embeddings to generate.
1180
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
1181
+ Data type of the generated embeddings.
1182
+
1183
+ Returns:
1184
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
1185
+ """
1186
+ assert len(w.shape) == 1
1187
+ w = w * 1000.0
1188
+
1189
+ half_dim = embedding_dim // 2
1190
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
1191
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
1192
+ emb = w.to(dtype)[:, None] * emb[None, :]
1193
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
1194
+ if embedding_dim % 2 == 1: # zero pad
1195
+ emb = torch.nn.functional.pad(emb, (0, 1))
1196
+ assert emb.shape == (w.shape[0], embedding_dim)
1197
+ return emb
1198
+
1199
+ @property
1200
+ def guidance_scale(self):
1201
+ return self._guidance_scale
1202
+
1203
+ @property
1204
+ def guidance_rescale(self):
1205
+ return self._guidance_rescale
1206
+
1207
+ @property
1208
+ def clip_skip(self):
1209
+ return self._clip_skip
1210
+
1211
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1212
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1213
+ # corresponds to doing no classifier free guidance.
1214
+ @property
1215
+ def do_classifier_free_guidance(self):
1216
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1217
+
1218
+ @property
1219
+ def cross_attention_kwargs(self):
1220
+ return self._cross_attention_kwargs
1221
+
1222
+ @property
1223
+ def denoising_end(self):
1224
+ return self._denoising_end
1225
+
1226
+ @property
1227
+ def num_timesteps(self):
1228
+ return self._num_timesteps
1229
+
1230
+ @property
1231
+ def interrupt(self):
1232
+ return self._interrupt
1233
+
1234
+ @torch.no_grad()
1235
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1236
+ def __call__(
1237
+ self,
1238
+ prompt: Union[str, List[str]] = None,
1239
+ prompt_2: Optional[Union[str, List[str]]] = None,
1240
+ height: Optional[int] = None,
1241
+ width: Optional[int] = None,
1242
+ num_inference_steps: int = 50,
1243
+ timesteps: List[int] = None,
1244
+ sigmas: List[float] = None,
1245
+ denoising_end: Optional[float] = None,
1246
+ guidance_scale: float = 5.0,
1247
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1248
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1249
+ num_images_per_prompt: Optional[int] = 1,
1250
+ eta: float = 0.0,
1251
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1252
+ latents: Optional[torch.Tensor] = None,
1253
+ prompt_embeds: Optional[torch.Tensor] = None,
1254
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
1255
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
1256
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
1257
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1258
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1259
+ output_type: Optional[str] = "pil",
1260
+ return_dict: bool = True,
1261
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1262
+ guidance_rescale: float = 0.0,
1263
+ original_size: Optional[Tuple[int, int]] = None,
1264
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1265
+ target_size: Optional[Tuple[int, int]] = None,
1266
+ negative_original_size: Optional[Tuple[int, int]] = None,
1267
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1268
+ negative_target_size: Optional[Tuple[int, int]] = None,
1269
+ clip_skip: Optional[int] = None,
1270
+ callback_on_step_end: Optional[
1271
+ Union[
1272
+ Callable[[int, int, Dict], None],
1273
+ PipelineCallback,
1274
+ MultiPipelineCallbacks,
1275
+ ]
1276
+ ] = None,
1277
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1278
+ **kwargs,
1279
+ ):
1280
+ r"""
1281
+ Function invoked when calling the pipeline for generation.
1282
+
1283
+ Args:
1284
+ prompt (`str` or `List[str]`, *optional*):
1285
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1286
+ instead.
1287
+ prompt_2 (`str` or `List[str]`, *optional*):
1288
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1289
+ used in both text-encoders
1290
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1291
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1292
+ Anything below 512 pixels won't work well for
1293
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1294
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1295
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1296
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1297
+ Anything below 512 pixels won't work well for
1298
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1299
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1300
+ num_inference_steps (`int`, *optional*, defaults to 50):
1301
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1302
+ expense of slower inference.
1303
+ timesteps (`List[int]`, *optional*):
1304
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1305
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1306
+ passed will be used. Must be in descending order.
1307
+ sigmas (`List[float]`, *optional*):
1308
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1309
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1310
+ will be used.
1311
+ denoising_end (`float`, *optional*):
1312
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1313
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1314
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
1315
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
1316
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1317
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
1318
+ guidance_scale (`float`, *optional*, defaults to 5.0):
1319
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1320
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1321
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1322
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1323
+ usually at the expense of lower image quality.
1324
+ negative_prompt (`str` or `List[str]`, *optional*):
1325
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1326
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1327
+ less than `1`).
1328
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1329
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1330
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1331
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1332
+ The number of images to generate per prompt.
1333
+ eta (`float`, *optional*, defaults to 0.0):
1334
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1335
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1336
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1337
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1338
+ to make generation deterministic.
1339
+ latents (`torch.Tensor`, *optional*):
1340
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1341
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1342
+ tensor will be generated by sampling using the supplied random `generator`.
1343
+ prompt_embeds (`torch.Tensor`, *optional*):
1344
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1345
+ provided, text embeddings will be generated from `prompt` input argument.
1346
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1347
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1348
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1349
+ argument.
1350
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
1351
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1352
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1353
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1354
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1355
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1356
+ input argument.
1357
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1358
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1359
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of
1360
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1361
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1362
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1363
+ output_type (`str`, *optional*, defaults to `"pil"`):
1364
+ The output format of the generated image. Choose between
1365
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1366
+ return_dict (`bool`, *optional*, defaults to `True`):
1367
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
1368
+ of a plain tuple.
1369
+ cross_attention_kwargs (`dict`, *optional*):
1370
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1371
+ `self.processor` in
1372
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1373
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1374
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
1375
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
1376
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
1377
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
1378
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1379
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1380
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1381
+ explained in section 2.2 of
1382
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1383
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1384
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1385
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1386
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1387
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1388
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1389
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1390
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1391
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1392
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1393
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1394
+ micro-conditioning as explained in section 2.2 of
1395
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1396
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1397
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1398
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
1399
+ micro-conditioning as explained in section 2.2 of
1400
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1401
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1402
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1403
+ To negatively condition the generation process based on a target image resolution. It should be the same
1404
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1405
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1406
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1407
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1408
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1409
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
1410
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1411
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1412
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1413
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1414
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1415
+ `._callback_tensor_inputs` attribute of your pipeline class.
1416
+
1417
+ Examples:
1418
+
1419
+ Returns:
1420
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
1421
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1422
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1423
+ """
1424
+
1425
+ callback = kwargs.pop("callback", None)
1426
+ callback_steps = kwargs.pop("callback_steps", None)
1427
+
1428
+ if callback is not None:
1429
+ deprecate(
1430
+ "callback",
1431
+ "1.0.0",
1432
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1433
+ )
1434
+ if callback_steps is not None:
1435
+ deprecate(
1436
+ "callback_steps",
1437
+ "1.0.0",
1438
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1439
+ )
1440
+
1441
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1442
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1443
+
1444
+ # 0. Default height and width to unet
1445
+ height = height or self.default_sample_size * self.vae_scale_factor
1446
+ width = width or self.default_sample_size * self.vae_scale_factor
1447
+
1448
+ original_size = original_size or (height, width)
1449
+ target_size = target_size or (height, width)
1450
+
1451
+ # 1. Check inputs. Raise error if not correct
1452
+ self.check_inputs(
1453
+ prompt,
1454
+ prompt_2,
1455
+ height,
1456
+ width,
1457
+ callback_steps,
1458
+ negative_prompt,
1459
+ negative_prompt_2,
1460
+ prompt_embeds,
1461
+ negative_prompt_embeds,
1462
+ pooled_prompt_embeds,
1463
+ negative_pooled_prompt_embeds,
1464
+ ip_adapter_image,
1465
+ ip_adapter_image_embeds,
1466
+ callback_on_step_end_tensor_inputs,
1467
+ )
1468
+
1469
+ self._guidance_scale = guidance_scale
1470
+ self._guidance_rescale = guidance_rescale
1471
+ self._clip_skip = clip_skip
1472
+ self._cross_attention_kwargs = cross_attention_kwargs
1473
+ self._denoising_end = denoising_end
1474
+ self._interrupt = False
1475
+
1476
+ # 2. Define call parameters
1477
+ if prompt is not None and isinstance(prompt, str):
1478
+ batch_size = 1
1479
+ elif prompt is not None and isinstance(prompt, list):
1480
+ batch_size = len(prompt)
1481
+ else:
1482
+ batch_size = prompt_embeds.shape[0]
1483
+
1484
+ device = self._execution_device
1485
+
1486
+ # 3. Encode input prompt
1487
+ lora_scale = (
1488
+ self.cross_attention_kwargs.get("scale", None)
1489
+ if self.cross_attention_kwargs is not None
1490
+ else None
1491
+ )
1492
+
1493
+ (
1494
+ prompt_embeds,
1495
+ negative_prompt_embeds,
1496
+ pooled_prompt_embeds,
1497
+ negative_pooled_prompt_embeds,
1498
+ ) = self.encode_prompt(
1499
+ prompt=prompt,
1500
+ prompt_2=prompt_2,
1501
+ device=device,
1502
+ num_images_per_prompt=num_images_per_prompt,
1503
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1504
+ negative_prompt=negative_prompt,
1505
+ negative_prompt_2=negative_prompt_2,
1506
+ prompt_embeds=prompt_embeds,
1507
+ negative_prompt_embeds=negative_prompt_embeds,
1508
+ pooled_prompt_embeds=pooled_prompt_embeds,
1509
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1510
+ lora_scale=lora_scale,
1511
+ clip_skip=self.clip_skip,
1512
+ )
1513
+
1514
+ # 4. Prepare timesteps
1515
+ timesteps, num_inference_steps = retrieve_timesteps(
1516
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1517
+ )
1518
+
1519
+ # 5. Prepare latent variables
1520
+ num_channels_latents = self.unet.config.in_channels
1522
+ latents = self.prepare_latents(
1523
+ batch_size * num_images_per_prompt,
1524
+ num_channels_latents,
1525
+ height,
1526
+ width,
1527
+ prompt_embeds.dtype,
1528
+ device,
1529
+ generator,
1530
+ latents,
1531
+ )
1532
+
1533
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1534
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1535
+
1536
+ # 7. Prepare added time ids & embeddings
1537
+ add_text_embeds = pooled_prompt_embeds
1538
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1539
+
1540
+ add_time_ids = self._get_add_time_ids(
1541
+ original_size,
1542
+ crops_coords_top_left,
1543
+ target_size,
1544
+ dtype=prompt_embeds.dtype,
1545
+ text_encoder_projection_dim=text_encoder_projection_dim,
1546
+ )
1547
+ if negative_original_size is not None and negative_target_size is not None:
1548
+ negative_add_time_ids = self._get_add_time_ids(
1549
+ negative_original_size,
1550
+ negative_crops_coords_top_left,
1551
+ negative_target_size,
1552
+ dtype=prompt_embeds.dtype,
1553
+ text_encoder_projection_dim=text_encoder_projection_dim,
1554
+ )
1555
+ else:
1556
+ negative_add_time_ids = add_time_ids
1557
+
1558
+ if self.do_classifier_free_guidance:
1559
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1560
+ add_text_embeds = torch.cat(
1561
+ [negative_pooled_prompt_embeds, add_text_embeds], dim=0
1562
+ )
1563
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1564
+
1565
+ prompt_embeds = prompt_embeds.to(device)
1566
+ add_text_embeds = add_text_embeds.to(device)
1567
+ add_time_ids = add_time_ids.to(device).repeat(
1568
+ batch_size * num_images_per_prompt, 1
1569
+ )
1570
+
1571
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1572
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1573
+ ip_adapter_image,
1574
+ ip_adapter_image_embeds,
1575
+ device,
1576
+ batch_size * num_images_per_prompt,
1577
+ self.do_classifier_free_guidance,
1578
+ )
1579
+
1580
+ # 8. Denoising loop
1581
+ num_warmup_steps = max(
1582
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
1583
+ )
1584
+
1585
+ # 8.1 Apply denoising_end
1586
+ if (
1587
+ self.denoising_end is not None
1588
+ and isinstance(self.denoising_end, float)
1589
+ and self.denoising_end > 0
1590
+ and self.denoising_end < 1
1591
+ ):
1592
+ discrete_timestep_cutoff = int(
1593
+ round(
1594
+ self.scheduler.config.num_train_timesteps
1595
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1596
+ )
1597
+ )
1598
+ num_inference_steps = len(
1599
+ list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
1600
+ )
1601
+ timesteps = timesteps[:num_inference_steps]
1602
+
1603
+ # 9. Optionally get Guidance Scale Embedding
1604
+ timestep_cond = None
1605
+ if self.unet.config.time_cond_proj_dim is not None:
1606
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
1607
+ batch_size * num_images_per_prompt
1608
+ )
1609
+ timestep_cond = self.get_guidance_scale_embedding(
1610
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1611
+ ).to(device=device, dtype=latents.dtype)
1612
+
1613
+ self._num_timesteps = len(timesteps)
1614
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1615
+ for i, t in enumerate(timesteps):
1616
+ if self.interrupt:
1617
+ continue
1618
+
1619
+ # expand the latents if we are doing classifier free guidance
1620
+ latent_model_input = (
1621
+ torch.cat([latents] * 2)
1622
+ if self.do_classifier_free_guidance
1623
+ else latents
1624
+ )
1625
+
1626
+ latent_model_input = self.scheduler.scale_model_input(
1627
+ latent_model_input, t
1628
+ )
1629
+
1630
+ # predict the noise residual
1631
+ added_cond_kwargs = {
1632
+ "text_embeds": add_text_embeds,
1633
+ "time_ids": add_time_ids,
1634
+ }
1635
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1636
+ added_cond_kwargs["image_embeds"] = image_embeds
1637
+ noise_pred = self.unet(
1638
+ latent_model_input,
1639
+ t,
1640
+ encoder_hidden_states=prompt_embeds,
1641
+ timestep_cond=timestep_cond,
1642
+ cross_attention_kwargs=self.cross_attention_kwargs,
1643
+ added_cond_kwargs=added_cond_kwargs,
1644
+ return_dict=False,
1645
+ )[0]
1646
+
1647
+ # perform guidance
1648
+ if self.do_classifier_free_guidance:
1649
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1650
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
1651
+ noise_pred_text - noise_pred_uncond
1652
+ )
1653
+
1654
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1655
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1656
+ noise_pred = rescale_noise_cfg(
1657
+ noise_pred,
1658
+ noise_pred_text,
1659
+ guidance_rescale=self.guidance_rescale,
1660
+ )
1661
+
1662
+ # compute the previous noisy sample x_t -> x_t-1
1663
+ latents_dtype = latents.dtype
1664
+ latents = self.scheduler.step(
1665
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1666
+ )[0]
1667
+ if latents.dtype != latents_dtype:
1668
+ if torch.backends.mps.is_available():
1669
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1670
+ latents = latents.to(latents_dtype)
1671
+
1672
+ if callback_on_step_end is not None:
1673
+ callback_kwargs = {}
1674
+ for k in callback_on_step_end_tensor_inputs:
1675
+ callback_kwargs[k] = locals()[k]
1676
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1677
+
1678
+ latents = callback_outputs.pop("latents", latents)
1679
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1680
+ negative_prompt_embeds = callback_outputs.pop(
1681
+ "negative_prompt_embeds", negative_prompt_embeds
1682
+ )
1683
+ add_text_embeds = callback_outputs.pop(
1684
+ "add_text_embeds", add_text_embeds
1685
+ )
1686
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1687
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1688
+ )
1689
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1690
+ negative_add_time_ids = callback_outputs.pop(
1691
+ "negative_add_time_ids", negative_add_time_ids
1692
+ )
1693
+
1694
+ # call the callback, if provided
1695
+ if i == len(timesteps) - 1 or (
1696
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1697
+ ):
1698
+ progress_bar.update()
1699
+ if callback is not None and i % callback_steps == 0:
1700
+ step_idx = i // getattr(self.scheduler, "order", 1)
1701
+ callback(step_idx, t, latents)
1702
+
1703
+ if XLA_AVAILABLE:
1704
+ xm.mark_step()
1705
+
1706
+ if not output_type == "latent":
1707
+ # make sure the VAE is in float32 mode, as it overflows in float16
1708
+ needs_upcasting = (
1709
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1710
+ )
1711
+
1712
+ if needs_upcasting:
1713
+ self.upcast_vae()
1714
+ latents = latents.to(
1715
+ next(iter(self.vae.post_quant_conv.parameters())).dtype
1716
+ )
1717
+ elif latents.dtype != self.vae.dtype:
1718
+ if torch.backends.mps.is_available():
1719
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1720
+ self.vae = self.vae.to(latents.dtype)
1721
+
1722
+ # unscale/denormalize the latents
1723
+ # denormalize with the mean and std if available and not None
1724
+ has_latents_mean = (
1725
+ hasattr(self.vae.config, "latents_mean")
1726
+ and self.vae.config.latents_mean is not None
1727
+ )
1728
+ has_latents_std = (
1729
+ hasattr(self.vae.config, "latents_std")
1730
+ and self.vae.config.latents_std is not None
1731
+ )
1732
+ if has_latents_mean and has_latents_std:
1733
+ latents_mean = (
1734
+ torch.tensor(self.vae.config.latents_mean)
1735
+ .view(1, 4, 1, 1)
1736
+ .to(latents.device, latents.dtype)
1737
+ )
1738
+ latents_std = (
1739
+ torch.tensor(self.vae.config.latents_std)
1740
+ .view(1, 4, 1, 1)
1741
+ .to(latents.device, latents.dtype)
1742
+ )
1743
+ latents = (
1744
+ latents * latents_std / self.vae.config.scaling_factor
1745
+ + latents_mean
1746
+ )
1747
+ else:
1748
+ latents = latents / self.vae.config.scaling_factor
1749
+
1750
+ image = self.vae.decode(latents)
1751
+
1752
+ # cast back to fp16 if needed
1753
+ if needs_upcasting:
1754
+ self.vae.to(dtype=torch.float16)
1755
+ else:
1756
+ image = latents
1757
+
1758
+ if not output_type == "latent":
1759
+ # apply watermark if available
1760
+ # if self.watermark is not None:
1761
+ # image = self.watermark.apply_watermark(image)
1762
+
1763
+ image = self.image_processor.postprocess(image, output_type=output_type)
1764
+
1765
+ # Offload all models
1766
+ self.maybe_free_model_hooks()
1767
+
1768
+ if not return_dict:
1769
+ return (image,)
1770
+
1771
+ return AniMemoryPipelineOutput(images=image)
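For reference, a minimal sketch of the step callback exposed by `__call__` above. This is illustration only, not part of the PR: it assumes `pipe` is an already loaded AniMemory pipeline, and the prompt, step count, and guidance scale are arbitrary.

```python
import torch

def log_step(pipeline, step, timestep, callback_kwargs):
    # Only tensors listed in `callback_on_step_end_tensor_inputs` are available here;
    # whatever is returned gets written back into the denoising loop.
    latents = callback_kwargs["latents"]
    print(f"step {step}: latent std {latents.std().item():.3f}")
    return callback_kwargs

image = pipe(
    prompt="a white cat sitting on a windowsill, sunlight",
    num_inference_steps=30,
    guidance_scale=6.0,
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```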
scheduler/scheduler_config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_class_name": "EulerDiscreteScheduler",
3
  "_diffusers_version": "0.26.0",
4
  "beta_end": 0.012,
5
  "beta_schedule": "scaled_linear",
 
1
  {
2
+ "_class_name": "EulerAncestralDiscreteXPredScheduler",
3
  "_diffusers_version": "0.26.0",
4
  "beta_end": 0.012,
5
  "beta_schedule": "scaled_linear",
scheduler/scheduling_euler_ancestral_discrete_x_pred.py ADDED
@@ -0,0 +1,246 @@
1
+ # Copyright 2024 Katherine Crowson, AniMemory Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ from diffusers.utils import logging
21
+ from diffusers.utils.torch_utils import randn_tensor
22
+ from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
23
+ EulerAncestralDiscreteScheduler,
24
+ EulerAncestralDiscreteSchedulerOutput,
25
+ rescale_zero_terminal_snr,
26
+ )
27
+
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ class EulerAncestralDiscreteXPredScheduler(EulerAncestralDiscreteScheduler):
33
+ """
34
+ Ancestral sampling with Euler method steps. This scheduler inherits from [`EulerAncestralDiscreteScheduler`]. Check the
35
+ superclass documentation for the args and returns.
36
+
37
+ For more details, see the original paper: https://arxiv.org/abs/2403.08381
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ num_train_timesteps: int = 1000,
43
+ beta_start: float = 0.0001,
44
+ beta_end: float = 0.02,
45
+ beta_schedule: str = "linear",
46
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
47
+ prediction_type: str = "epsilon",
48
+ timestep_spacing: str = "linspace",
49
+ steps_offset: int = 0,
50
+ ):
51
+ super(EulerAncestralDiscreteXPredScheduler, self).__init__(
52
+ num_train_timesteps,
53
+ beta_start,
54
+ beta_end,
55
+ beta_schedule,
56
+ trained_betas,
57
+ prediction_type,
58
+ timestep_spacing,
59
+ steps_offset,
60
+ )
61
+
62
+ sigmas = np.array((1 - self.alphas_cumprod) ** 0.5, dtype=np.float32)
63
+ self.sigmas = torch.from_numpy(sigmas)
64
+
65
+ def rescale_betas_zero_snr(self):
66
+ self.betas = rescale_zero_terminal_snr(self.betas)
67
+ self.alphas = 1.0 - self.betas
68
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
69
+ sigmas = np.array((1 - self.alphas_cumprod) ** 0.5)
70
+ self.sigmas = torch.from_numpy(sigmas)
71
+
72
+ @property
73
+ def init_noise_sigma(self):
74
+ return 1.0
75
+
76
+ def scale_model_input(
77
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
78
+ ) -> torch.FloatTensor:
79
+ self.is_scale_input_called = True
80
+ # x-prediction inputs are consumed unscaled, so the sample is returned unchanged
81
+ return sample
82
+
83
+ def set_timesteps(
84
+ self, num_inference_steps: int, device: Union[str, torch.device] = None
85
+ ):
86
+ """
87
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
88
+
89
+ Args:
90
+ num_inference_steps (`int`):
91
+ the number of diffusion steps used when generating samples with a pre-trained model.
92
+ device (`str` or `torch.device`, optional):
93
+ the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
94
+ """
95
+ self.num_inference_steps = num_inference_steps
96
+
97
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
98
+ if self.config.timestep_spacing == "linspace":
99
+ timesteps = np.linspace(
100
+ 0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float
101
+ )[::-1].copy()
102
+ elif self.config.timestep_spacing == "leading":
103
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
104
+ # creates integer timesteps by multiplying by ratio
105
+ # casting to int to avoid issues when num_inference_step is power of 3
106
+ timesteps = (
107
+ (np.arange(0, num_inference_steps) * step_ratio)
108
+ .round()[::-1]
109
+ .copy()
110
+ .astype(float)
111
+ )
112
+ timesteps += self.config.steps_offset
113
+ elif self.config.timestep_spacing == "trailing":
114
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
115
+ # creates integer timesteps by multiplying by ratio
116
+ # casting to int to avoid issues when num_inference_step is power of 3
117
+ timesteps = (
118
+ (np.arange(self.config.num_train_timesteps, 0, -step_ratio))
119
+ .round()
120
+ .copy()
121
+ .astype(float)
122
+ )
123
+ timesteps -= 1
124
+ else:
125
+ raise ValueError(
126
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
127
+ )
128
+
129
+ sigmas = np.array((1 - self.alphas_cumprod) ** 0.5)
130
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
131
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
132
+
133
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
134
+ if str(device).startswith("mps"):
135
+ # mps does not support float64
136
+ self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
137
+ else:
138
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
139
+
140
+ def step(
141
+ self,
142
+ model_output: torch.FloatTensor,
143
+ timestep: Union[float, torch.FloatTensor],
144
+ sample: torch.FloatTensor,
145
+ generator: Optional[torch.Generator] = None,
146
+ return_dict: bool = True,
147
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
148
+ """
149
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
150
+ process from the learned model outputs (most often the predicted noise).
151
+
152
+ Args:
153
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
154
+ timestep (`float`): current timestep in the diffusion chain.
155
+ sample (`torch.FloatTensor`):
156
+ current instance of sample being created by diffusion process.
157
+ generator (`torch.Generator`, optional): Random number generator.
158
+ return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
159
+
160
+ Returns:
161
+ [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
162
+ [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
163
+ a `tuple`. When returning a tuple, the first element is the sample tensor.
164
+
165
+ """
166
+
167
+ if (
168
+ isinstance(timestep, int)
169
+ or isinstance(timestep, torch.IntTensor)
170
+ or isinstance(timestep, torch.LongTensor)
171
+ ):
172
+ raise ValueError(
173
+ (
174
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
175
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
176
+ " one of the `scheduler.timesteps` as a timestep."
177
+ ),
178
+ )
179
+
180
+ if isinstance(timestep, torch.Tensor):
181
+ timestep = timestep.to(self.timesteps.device)
182
+
183
+ step_index = (self.timesteps == timestep).nonzero().item()
184
+
185
+ if self.config.prediction_type == "sample":
186
+ pred_original_sample = model_output
187
+ else:
188
+ raise ValueError(
189
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
190
+ )
191
+
192
+ sigma_t = self.sigmas[step_index]
193
+ sigma_s = self.sigmas[step_index + 1]
194
+ alpha_t = (1 - sigma_t**2) ** 0.5
195
+ alpha_s = (1 - sigma_s**2) ** 0.5
196
+
197
+ coef_sample = (sigma_s / sigma_t) ** 2 * alpha_t / alpha_s
198
+ coef_noise = (sigma_s / sigma_t) * (1 - (alpha_t / alpha_s) ** 2) ** 0.5
199
+ coef_x = alpha_s * (1 - alpha_t**2 / alpha_s**2) / sigma_t**2
200
+
201
+ device = model_output.device
202
+ noise = randn_tensor(
203
+ model_output.shape,
204
+ dtype=model_output.dtype,
205
+ device=device,
206
+ generator=generator,
207
+ )
208
+ prev_sample = (
209
+ coef_sample * sample + coef_x * pred_original_sample + coef_noise * noise
210
+ )
211
+
212
+ if not return_dict:
213
+ return (prev_sample,)
214
+
215
+ return EulerAncestralDiscreteSchedulerOutput(
216
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
217
+ )
218
+
219
+ def add_noise(
220
+ self,
221
+ original_samples: torch.FloatTensor,
222
+ noise: torch.FloatTensor,
223
+ timesteps: torch.FloatTensor,
224
+ ) -> torch.FloatTensor:
225
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
226
+ sigmas = self.sigmas.to(
227
+ device=original_samples.device, dtype=original_samples.dtype
228
+ )
229
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
230
+ # mps does not support float64
231
+ schedule_timesteps = self.timesteps.to(
232
+ original_samples.device, dtype=torch.float32
233
+ )
234
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
235
+ else:
236
+ schedule_timesteps = self.timesteps.to(original_samples.device)
237
+ timesteps = timesteps.to(original_samples.device)
238
+
239
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
240
+
241
+ sigma = sigmas[step_indices].flatten()
242
+ while len(sigma.shape) < len(original_samples.shape):
243
+ sigma = sigma.unsqueeze(-1)
244
+
245
+ noisy_samples = original_samples + noise * sigma
246
+ return noisy_samples
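The ancestral x-prediction update in `step()` can be sanity-checked in isolation. The following standalone sketch (illustration only, NumPy instead of torch) re-derives the three coefficients and verifies that the final step, where `sigma_s` is the trailing `0.0` appended in `set_timesteps`, returns the model's predicted clean sample exactly.

```python
import numpy as np

def xpred_ancestral_step(x_t, x0_pred, noise, sigma_t, sigma_s):
    # Same coefficients as `step()` above, written out for a single update.
    alpha_t = (1 - sigma_t**2) ** 0.5
    alpha_s = (1 - sigma_s**2) ** 0.5
    coef_sample = (sigma_s / sigma_t) ** 2 * alpha_t / alpha_s
    coef_noise = (sigma_s / sigma_t) * (1 - (alpha_t / alpha_s) ** 2) ** 0.5
    coef_x = alpha_s * (1 - alpha_t**2 / alpha_s**2) / sigma_t**2
    return coef_sample * x_t + coef_x * x0_pred + coef_noise * noise

rng = np.random.default_rng(0)
x_t, x0_pred, noise = rng.standard_normal((3, 8))

# An intermediate step mixes the current sample, the x0 prediction, and fresh noise.
mid = xpred_ancestral_step(x_t, x0_pred, noise, sigma_t=0.8, sigma_s=0.5)

# The final step (sigma_s == 0) collapses to the predicted clean sample.
last = xpred_ancestral_step(x_t, x0_pred, noise, sigma_t=0.3, sigma_s=0.0)
assert np.allclose(last, x0_pred)
```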
text_encoder/animemory_t5.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright 2024 AniMemory Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import torch
18
+ from safetensors.torch import load_file
19
+ from transformers.models.t5.configuration_t5 import T5Config
20
+ from transformers.models.t5.modeling_t5 import T5Stack
21
+
22
+
23
+ class AniMemoryT5(torch.nn.Module):
24
+ def __init__(self, config: T5Config, embed_tokens=None):
25
+ super().__init__()
26
+ self.encoder = T5Stack(config, embed_tokens)
27
+ self.embed_tokens_encoder = torch.nn.Embedding(250002, 4096, padding_idx=1)
28
+
29
+ @classmethod
30
+ def from_pretrained(
31
+ cls,
32
+ pretrained_model_name_or_path,
33
+ subfolder="",
34
+ embed_tokens=None,
35
+ emb_name="weights.safetensors",
36
+ torch_dtype=torch.float16,
37
+ ):
38
+ cls.dtype = torch_dtype
39
+ config = T5Stack.config_class.from_pretrained(
40
+ pretrained_model_name_or_path, subfolder=subfolder
41
+ )
42
+ model = cls(config=config, embed_tokens=embed_tokens)
43
+ model.encoder = T5Stack.from_pretrained(
44
+ pretrained_model_name_or_path, subfolder=subfolder
45
+ )
46
+ embed_tokens_encoder_path = load_file(
47
+ os.path.join(pretrained_model_name_or_path, subfolder, emb_name)
48
+ )
49
+ model.embed_tokens_encoder.load_state_dict(embed_tokens_encoder_path)
50
+ model.encoder.to(torch_dtype)
51
+ model.embed_tokens_encoder.to(torch_dtype)
52
+ return model
53
+
54
+ def to(self, *args, **kwargs):
55
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
56
+ *args, **kwargs
57
+ )
58
+ super(AniMemoryT5, self).to(*args, **kwargs)
59
+ self.dtype = dtype if dtype is not None else self.dtype
60
+ self.device = device if device is not None else self.device
61
+ return self
62
+
63
+ def make_attn_mask(self, attn_mask):
64
+ seq_len = attn_mask.shape[1]
65
+ query = attn_mask.unsqueeze(1).float()
66
+ attn_mask = (
67
+ query.repeat([1, seq_len, 1]).unsqueeze(1).repeat([1, self.num_head, 1, 1])
68
+ )
69
+ attn_mask = attn_mask.view([-1, seq_len, seq_len])
70
+ return attn_mask
71
+
72
+ def forward(self, text, attention_mask):
73
+ embeddings = self.embed_tokens_encoder(text)
74
+ encoder_outputs = self.encoder(
75
+ inputs_embeds=embeddings,
76
+ attention_mask=attention_mask,
77
+ output_hidden_states=True,
78
+ )
79
+ hidden_states = encoder_outputs.hidden_states[-2]
80
+ hidden_states = self.encoder.final_layer_norm(hidden_states)
81
+ return hidden_states, hidden_states
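A hedged sketch of driving this encoder on its own, outside the pipeline. The local snapshot path, the pairing with the repo's `tokenizer` subfolder, and the sequence length of 77 are assumptions for illustration, not part of this PR.

```python
import torch
from transformers import AutoTokenizer
from animemory_t5 import AniMemoryT5  # this file, imported locally

repo = "/path/to/AniMemory-alpha"  # assumed local snapshot of the model repo
tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer")
encoder = AniMemoryT5.from_pretrained(
    repo, subfolder="text_encoder", torch_dtype=torch.float16
).to("cuda")

inputs = tokenizer(
    "a white cat sitting on a windowsill",
    padding="max_length", max_length=77, truncation=True, return_tensors="pt",
).to("cuda")
with torch.no_grad():
    # `forward` returns the layer-normed penultimate hidden states twice (see above).
    hidden_states, _ = encoder(inputs.input_ids, inputs.attention_mask)
print(hidden_states.shape)  # (1, 77, hidden_dim)
```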
text_encoder_2/animemory_altclip.py ADDED
@@ -0,0 +1,119 @@
1
+ # Copyright 2024 AniMemory Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import torch
18
+ from safetensors.torch import load_file
19
+ from transformers import CLIPTextConfig, CLIPTextModelWithProjection
20
+
21
+
22
+ class AniMemoryAltCLip(torch.nn.Module):
23
+ def __init__(self, config: CLIPTextConfig):
24
+ super().__init__()
25
+ self.model_hf = CLIPTextModelWithProjection(config)
26
+ self.linear_proj = torch.nn.Linear(in_features=1280, out_features=1280)
27
+
28
+ @classmethod
29
+ def from_pretrained(
30
+ cls,
31
+ pretrained_model_name_or_path,
32
+ subfolder="",
33
+ linear_proj_name="weights.safetensors",
34
+ torch_dtype=torch.float16,
35
+ ):
36
+ cls.dtype = torch_dtype
37
+ config = CLIPTextModelWithProjection.config_class.from_pretrained(
38
+ pretrained_model_name_or_path, subfolder=subfolder
39
+ )
40
+ model = cls(config=config)
41
+ model.model_hf = CLIPTextModelWithProjection.from_pretrained(
42
+ pretrained_model_name_or_path, subfolder=subfolder
43
+ )
44
+ linear_proj_state = load_file(
45
+ os.path.join(pretrained_model_name_or_path, subfolder, linear_proj_name)
46
+ )
47
+ model.linear_proj.load_state_dict(linear_proj_state)
48
+ return model
49
+
50
+ def to(self, *args, **kwargs):
51
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
52
+ *args, **kwargs
53
+ )
54
+ super(AniMemoryAltCLip, self).to(*args, **kwargs)
55
+ self.dtype = dtype if dtype is not None else self.dtype
56
+ self.device = device if device is not None else self.device
57
+ return self
58
+
59
+ def expand_mask(self, mask=None, dtype="", tgt_len=None):
60
+ """
61
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
62
+ """
63
+ bsz, src_len = mask.size()
64
+ tgt_len = tgt_len if tgt_len is not None else src_len
65
+
66
+ expanded_mask = (
67
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
68
+ )
69
+
70
+ inverted_mask = 1.0 - expanded_mask
71
+
72
+ return inverted_mask.masked_fill(
73
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
74
+ )
75
+
76
+ def make_attn_mask(self, attn_mask):
77
+ seq_len = attn_mask.shape[1]
78
+ query = attn_mask.unsqueeze(1).float()
79
+ attn_mask = (
80
+ query.repeat([1, seq_len, 1]).unsqueeze(1).repeat([1, self.num_head, 1, 1])
81
+ )
82
+ attn_mask = attn_mask.view([-1, seq_len, seq_len])
83
+ return attn_mask
84
+
85
+ def gradient_checkpointing_enable(
86
+ self,
87
+ ):
88
+ self.model_hf.gradient_checkpointing_enable()
89
+
90
+ def forward(self, text, attention_mask):
91
+ hidden_states = self.model_hf.text_model.embeddings(
92
+ input_ids=text, position_ids=None
93
+ )
94
+ if attention_mask is None:
95
+ print("Warning: attention_mask is None in altclip!")
96
+ new_attn_mask = (
97
+ self.expand_mask(attention_mask, hidden_states.dtype)
98
+ if attention_mask is not None
99
+ else None
100
+ )
101
+ encoder_outputs = self.model_hf.text_model.encoder(
102
+ inputs_embeds=hidden_states,
103
+ attention_mask=new_attn_mask,
104
+ causal_attention_mask=None,
105
+ output_attentions=False,
106
+ output_hidden_states=True,
107
+ return_dict=True,
108
+ )
109
+ last_hidden_state = encoder_outputs[0]
110
+ last_hidden_state = self.model_hf.text_model.final_layer_norm(last_hidden_state)
111
+ last_hidden_state = (
112
+ last_hidden_state[torch.arange(last_hidden_state.shape[0]), 0]
113
+ @ self.model_hf.text_projection.weight
114
+ )
115
+ pooled_output = self.linear_proj(last_hidden_state)
116
+
117
+ extra_features = encoder_outputs.hidden_states[-2]
118
+ extra_features = self.model_hf.text_model.final_layer_norm(extra_features)
119
+ return extra_features, pooled_output
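Illustration only: the two values returned by `forward` above, with the same assumptions as the previous sketch (a local snapshot path, a matching `tokenizer_2` subfolder, and an illustrative sequence length).

```python
import torch
from transformers import AutoTokenizer
from animemory_altclip import AniMemoryAltCLip  # this file, imported locally

repo = "/path/to/AniMemory-alpha"  # assumed local snapshot
tokenizer_2 = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
clip_encoder = AniMemoryAltCLip.from_pretrained(repo, subfolder="text_encoder_2").to("cuda")

tokens = tokenizer_2(
    "a white cat sitting on a windowsill",
    padding="max_length", max_length=77, truncation=True, return_tensors="pt",
).to("cuda")
with torch.no_grad():
    extra_features, pooled = clip_encoder(tokens.input_ids, tokens.attention_mask)

# `extra_features`: layer-normed penultimate token features, (batch, seq_len, hidden_dim),
#   presumably concatenated into the cross-attention context by the pipeline.
# `pooled`: the first-token embedding passed through `text_projection` and `linear_proj`,
#   (batch, 1280), presumably used as the pooled text embedding.
print(extra_features.shape, pooled.shape)
```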
vae/modeling_movq.py ADDED
@@ -0,0 +1,539 @@
1
+ # Copyright 2024 Kandinsky 3.0 Model Team, AniMemory Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from types import SimpleNamespace
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from packaging import version
22
+ from safetensors.torch import load_file
23
+
24
+ from diffusers.utils.accelerate_utils import apply_forward_hook
25
+
26
+
27
+ def nonlinearity(x):
28
+ return x * torch.sigmoid(x)
29
+
30
+
31
+ class SpatialNorm(nn.Module):
32
+ def __init__(
33
+ self,
34
+ f_channels,
35
+ zq_channels=None,
36
+ norm_layer=nn.GroupNorm,
37
+ freeze_norm_layer=False,
38
+ add_conv=False,
39
+ **norm_layer_params,
40
+ ):
41
+ super().__init__()
42
+ self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
43
+ if zq_channels is not None:
44
+ if freeze_norm_layer:
45
+ for p in self.norm_layer.parameters():
46
+ p.requires_grad = False
47
+ self.add_conv = add_conv
48
+ if self.add_conv:
49
+ self.conv = nn.Conv2d(
50
+ zq_channels, zq_channels, kernel_size=3, stride=1, padding=1
51
+ )
52
+ self.conv_y = nn.Conv2d(
53
+ zq_channels, f_channels, kernel_size=1, stride=1, padding=0
54
+ )
55
+ self.conv_b = nn.Conv2d(
56
+ zq_channels, f_channels, kernel_size=1, stride=1, padding=0
57
+ )
58
+
59
+ def forward(self, f, zq=None):
60
+ norm_f = self.norm_layer(f)
61
+ if zq is not None:
62
+ f_size = f.shape[-2:]
63
+ if (
64
+ version.parse(torch.__version__) < version.parse("2.1")
65
+ and zq.dtype == torch.bfloat16
66
+ ):
67
+ zq = zq.to(dtype=torch.float32)
68
+ zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
69
+ zq = zq.to(dtype=torch.bfloat16)
70
+ else:
71
+ zq = torch.nn.functional.interpolate(zq, size=f_size, mode="nearest")
72
+ if self.add_conv:
73
+ zq = self.conv(zq)
74
+ norm_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
75
+ return norm_f
76
+
77
+
78
+ def Normalize(in_channels, zq_ch=None, add_conv=None):
79
+ return SpatialNorm(
80
+ in_channels,
81
+ zq_ch,
82
+ norm_layer=nn.GroupNorm,
83
+ freeze_norm_layer=False,
84
+ add_conv=add_conv,
85
+ num_groups=32,
86
+ eps=1e-6,
87
+ affine=True,
88
+ )
89
+
90
+
91
+ class Upsample(nn.Module):
92
+ def __init__(self, in_channels, with_conv):
93
+ super().__init__()
94
+ self.with_conv = with_conv
95
+ if self.with_conv:
96
+ self.conv = torch.nn.Conv2d(
97
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
98
+ )
99
+
100
+ def forward(self, x):
101
+ if (
102
+ version.parse(torch.__version__) < version.parse("2.1")
103
+ and x.dtype == torch.bfloat16
104
+ ):
105
+ x = x.to(dtype=torch.float32)
106
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
107
+ x = x.to(dtype=torch.bfloat16)
108
+ else:
109
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
110
+ if self.with_conv:
111
+ x = self.conv(x)
112
+ return x
113
+
114
+
115
+ class Downsample(nn.Module):
116
+ def __init__(self, in_channels, with_conv):
117
+ super().__init__()
118
+ self.with_conv = with_conv
119
+ if self.with_conv:
120
+ self.conv = torch.nn.Conv2d(
121
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
122
+ )
123
+
124
+ def forward(self, x):
125
+ if self.with_conv:
126
+ pad = (0, 1, 0, 1)
127
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
128
+ x = self.conv(x)
129
+ else:
130
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
131
+ return x
132
+
133
+
134
+ class ResnetBlock(nn.Module):
135
+ def __init__(
136
+ self,
137
+ *,
138
+ in_channels,
139
+ out_channels=None,
140
+ conv_shortcut=False,
141
+ dropout,
142
+ temb_channels=512,
143
+ zq_ch=None,
144
+ add_conv=False,
145
+ ):
146
+ super().__init__()
147
+ self.in_channels = in_channels
148
+ out_channels = in_channels if out_channels is None else out_channels
149
+ self.out_channels = out_channels
150
+ self.use_conv_shortcut = conv_shortcut
151
+
152
+ self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
153
+ self.conv1 = torch.nn.Conv2d(
154
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
155
+ )
156
+ if temb_channels > 0:
157
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
158
+ self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
159
+ self.dropout = torch.nn.Dropout(dropout)
160
+ self.conv2 = torch.nn.Conv2d(
161
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
162
+ )
163
+ if self.in_channels != self.out_channels:
164
+ if self.use_conv_shortcut:
165
+ self.conv_shortcut = torch.nn.Conv2d(
166
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
167
+ )
168
+ else:
169
+ self.nin_shortcut = torch.nn.Conv2d(
170
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
171
+ )
172
+
173
+ def forward(self, x, temb, zq=None):
174
+ h = x
175
+ h = self.norm1(h, zq)
176
+ h = nonlinearity(h)
177
+ h = self.conv1(h)
178
+
179
+ if temb is not None:
180
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
181
+
182
+ h = self.norm2(h, zq)
183
+ h = nonlinearity(h)
184
+ h = self.dropout(h)
185
+ h = self.conv2(h)
186
+
187
+ if self.in_channels != self.out_channels:
188
+ if self.use_conv_shortcut:
189
+ x = self.conv_shortcut(x)
190
+ else:
191
+ x = self.nin_shortcut(x)
192
+
193
+ return x + h
194
+
195
+
196
+ class AttnBlock(nn.Module):
197
+ def __init__(self, in_channels, zq_ch=None, add_conv=False):
198
+ super().__init__()
199
+ self.in_channels = in_channels
200
+
201
+ self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
202
+ self.q = torch.nn.Conv2d(
203
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
204
+ )
205
+ self.k = torch.nn.Conv2d(
206
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
207
+ )
208
+ self.v = torch.nn.Conv2d(
209
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
210
+ )
211
+ self.proj_out = torch.nn.Conv2d(
212
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
213
+ )
214
+
215
+ def forward(self, x, zq=None):
216
+ h_ = x
217
+ h_ = self.norm(h_, zq)
218
+ q = self.q(h_)
219
+ k = self.k(h_)
220
+ v = self.v(h_)
221
+
222
+ # compute attention
223
+ b, c, h, w = q.shape
224
+ q = q.reshape(b, c, h * w)
225
+ q = q.permute(0, 2, 1)
226
+ k = k.reshape(b, c, h * w)
227
+ w_ = torch.bmm(q, k)
228
+ w_ = w_ * (int(c) ** (-0.5))
229
+ w_ = torch.nn.functional.softmax(w_, dim=2)
230
+
231
+ # attend to values
232
+ v = v.reshape(b, c, h * w)
233
+ w_ = w_.permute(0, 2, 1)
234
+ h_ = torch.bmm(v, w_)
235
+ h_ = h_.reshape(b, c, h, w)
236
+
237
+ h_ = self.proj_out(h_)
238
+
239
+ return x + h_
240
+
241
+
242
+ class Encoder(nn.Module):
243
+ def __init__(
244
+ self,
245
+ *,
246
+ ch,
247
+ out_ch,
248
+ ch_mult=(1, 2, 4, 8),
249
+ num_res_blocks,
250
+ attn_resolutions,
251
+ dropout=0.0,
252
+ resamp_with_conv=True,
253
+ in_channels,
254
+ resolution,
255
+ z_channels,
256
+ double_z=True,
257
+ **ignore_kwargs,
258
+ ):
259
+ super().__init__()
260
+ self.ch = ch
261
+ self.temb_ch = 0
262
+ self.num_resolutions = len(ch_mult)
263
+ self.num_res_blocks = num_res_blocks
264
+ self.resolution = resolution
265
+ self.in_channels = in_channels
266
+
267
+ # downsampling
268
+ self.conv_in = torch.nn.Conv2d(
269
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
270
+ )
271
+
272
+ curr_res = resolution
273
+ in_ch_mult = (1,) + tuple(ch_mult)
274
+ self.down = nn.ModuleList()
275
+ for i_level in range(self.num_resolutions):
276
+ block = nn.ModuleList()
277
+ attn = nn.ModuleList()
278
+ block_in = ch * in_ch_mult[i_level]
279
+ block_out = ch * ch_mult[i_level]
280
+ for i_block in range(self.num_res_blocks):
281
+ block.append(
282
+ ResnetBlock(
283
+ in_channels=block_in,
284
+ out_channels=block_out,
285
+ temb_channels=self.temb_ch,
286
+ dropout=dropout,
287
+ )
288
+ )
289
+ block_in = block_out
290
+ if curr_res in attn_resolutions:
291
+ attn.append(AttnBlock(block_in))
292
+ down = nn.Module()
293
+ down.block = block
294
+ down.attn = attn
295
+ if i_level != self.num_resolutions - 1:
296
+ down.downsample = Downsample(block_in, resamp_with_conv)
297
+ curr_res = curr_res // 2
298
+ self.down.append(down)
299
+
300
+ # middle
301
+ self.mid = nn.Module()
302
+ self.mid.block_1 = ResnetBlock(
303
+ in_channels=block_in,
304
+ out_channels=block_in,
305
+ temb_channels=self.temb_ch,
306
+ dropout=dropout,
307
+ )
308
+ self.mid.attn_1 = AttnBlock(block_in)
309
+ self.mid.block_2 = ResnetBlock(
310
+ in_channels=block_in,
311
+ out_channels=block_in,
312
+ temb_channels=self.temb_ch,
313
+ dropout=dropout,
314
+ )
315
+
316
+ # end
317
+ self.norm_out = Normalize(block_in)
318
+ self.conv_out = torch.nn.Conv2d(
319
+ block_in,
320
+ 2 * z_channels if double_z else z_channels,
321
+ kernel_size=3,
322
+ stride=1,
323
+ padding=1,
324
+ )
325
+
326
+ def forward(self, x):
327
+ temb = None
328
+
329
+ # downsampling
330
+ hs = [self.conv_in(x)]
331
+ for i_level in range(self.num_resolutions):
332
+ for i_block in range(self.num_res_blocks):
333
+ h = self.down[i_level].block[i_block](hs[-1], temb)
334
+ if len(self.down[i_level].attn) > 0:
335
+ h = self.down[i_level].attn[i_block](h)
336
+ hs.append(h)
337
+ if i_level != self.num_resolutions - 1:
338
+ hs.append(self.down[i_level].downsample(hs[-1]))
339
+
340
+ # middle
341
+ h = hs[-1]
342
+ h = self.mid.block_1(h, temb)
343
+ h = self.mid.attn_1(h)
344
+ h = self.mid.block_2(h, temb)
345
+
346
+ # end
347
+ h = self.norm_out(h)
348
+ h = nonlinearity(h)
349
+ h = self.conv_out(h)
350
+ return h
351
+
352
+
353
+ class Decoder(nn.Module):
354
+ def __init__(
355
+ self,
356
+ *,
357
+ ch,
358
+ out_ch,
359
+ ch_mult=(1, 2, 4, 8),
360
+ num_res_blocks,
361
+ attn_resolutions,
362
+ dropout=0.0,
363
+ resamp_with_conv=True,
364
+ in_channels,
365
+ resolution,
366
+ z_channels,
367
+ give_pre_end=False,
368
+ zq_ch=None,
369
+ add_conv=False,
370
+ **ignorekwargs,
371
+ ):
372
+ super().__init__()
373
+ self.ch = ch
374
+ self.temb_ch = 0
375
+ self.num_resolutions = len(ch_mult)
376
+ self.num_res_blocks = num_res_blocks
377
+ self.resolution = resolution
378
+ self.in_channels = in_channels
379
+ self.give_pre_end = give_pre_end
380
+
381
+ block_in = ch * ch_mult[self.num_resolutions - 1]
382
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
383
+ self.z_shape = (1, z_channels, curr_res, curr_res)
384
+
385
+ # z to block_in
386
+ self.conv_in = torch.nn.Conv2d(
387
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
388
+ )
389
+
390
+ # middle
391
+ self.mid = nn.Module()
392
+ self.mid.block_1 = ResnetBlock(
393
+ in_channels=block_in,
394
+ out_channels=block_in,
395
+ temb_channels=self.temb_ch,
396
+ dropout=dropout,
397
+ zq_ch=zq_ch,
398
+ add_conv=add_conv,
399
+ )
400
+ self.mid.attn_1 = AttnBlock(block_in, zq_ch, add_conv=add_conv)
401
+ self.mid.block_2 = ResnetBlock(
402
+ in_channels=block_in,
403
+ out_channels=block_in,
404
+ temb_channels=self.temb_ch,
405
+ dropout=dropout,
406
+ zq_ch=zq_ch,
407
+ add_conv=add_conv,
408
+ )
409
+
410
+ # upsampling
411
+ self.up = nn.ModuleList()
412
+ for i_level in reversed(range(self.num_resolutions)):
413
+ block = nn.ModuleList()
414
+ attn = nn.ModuleList()
415
+ block_out = ch * ch_mult[i_level]
416
+ for _ in range(self.num_res_blocks + 1):
417
+ block.append(
418
+ ResnetBlock(
419
+ in_channels=block_in,
420
+ out_channels=block_out,
421
+ temb_channels=self.temb_ch,
422
+ dropout=dropout,
423
+ zq_ch=zq_ch,
424
+ add_conv=add_conv,
425
+ )
426
+ )
427
+ block_in = block_out
428
+ if curr_res in attn_resolutions:
429
+ attn.append(AttnBlock(block_in, zq_ch, add_conv=add_conv))
430
+ up = nn.Module()
431
+ up.block = block
432
+ up.attn = attn
433
+ if i_level != 0:
434
+ up.upsample = Upsample(block_in, resamp_with_conv)
435
+ curr_res = curr_res * 2
436
+ self.up.insert(0, up)
437
+
438
+ # end
439
+ self.norm_out = Normalize(block_in, zq_ch, add_conv=add_conv)
440
+ self.conv_out = torch.nn.Conv2d(
441
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
442
+ )
443
+
444
+ def forward(self, z, zq):
445
+ self.last_z_shape = z.shape
446
+ temb = None
447
+
448
+ h = self.conv_in(z)
449
+
450
+ # middle
451
+ h = self.mid.block_1(h, temb, zq)
452
+ h = self.mid.attn_1(h, zq)
453
+ h = self.mid.block_2(h, temb, zq)
454
+
455
+ # upsampling
456
+ for i_level in reversed(range(self.num_resolutions)):
457
+ for i_block in range(self.num_res_blocks + 1):
458
+ h = self.up[i_level].block[i_block](h, temb, zq)
459
+ if len(self.up[i_level].attn) > 0:
460
+ h = self.up[i_level].attn[i_block](h, zq)
461
+ if i_level != 0:
462
+ h = self.up[i_level].upsample(h)
463
+
464
+ # end
465
+ if self.give_pre_end:
466
+ return h
467
+
468
+ h = self.norm_out(h, zq)
469
+ h = nonlinearity(h)
470
+ h = self.conv_out(h)
471
+ return h
472
+
473
+
474
+ # Modified from MoVQ in https://github.com/ai-forever/Kandinsky-3/blob/main/kandinsky3/movq.py
475
+ class MoVQ(nn.Module):
476
+ def __init__(self, generator_params: dict):
477
+ super().__init__()
478
+ z_channels = generator_params["z_channels"]
479
+ self.config = SimpleNamespace(**generator_params)
480
+ self.encoder = Encoder(**generator_params)
481
+ self.quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
482
+ self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
483
+ self.decoder = Decoder(zq_ch=z_channels, **generator_params)
484
+ self.dtype = None
485
+ self.device = None
486
+
487
+ @staticmethod
488
+ def get_model_config(pretrained_model_name_or_path, subfolder):
489
+ config_path = os.path.join(
490
+ pretrained_model_name_or_path, subfolder, "config.json"
491
+ )
492
+ assert os.path.exists(config_path), "config file not exists."
493
+ with open(config_path, "r") as f:
494
+ config = json.loads(f.read())
495
+ return config
496
+
497
+ @classmethod
498
+ def from_pretrained(
499
+ cls,
500
+ pretrained_model_name_or_path,
501
+ subfolder="",
502
+ torch_dtype=torch.float32,
503
+ ):
504
+ config = cls.get_model_config(pretrained_model_name_or_path, subfolder)
505
+ model = cls(generator_params=config)
506
+ ckpt_path = os.path.join(
507
+ pretrained_model_name_or_path, subfolder, "movq_model.safetensors"
508
+ )
509
+ assert os.path.exists(
510
+ ckpt_path
511
+ ), f"ckpt path not exists, please check {ckpt_path}"
512
+ assert torch_dtype != torch.float16, "torch_dtype doesn't support fp16"
513
+ ckpt_weight = load_file(ckpt_path)
514
+ model.load_state_dict(ckpt_weight, strict=True)
515
+ model.to(dtype=torch_dtype)
516
+ return model
517
+
518
+ def to(self, *args, **kwargs):
519
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
520
+ *args, **kwargs
521
+ )
522
+ super(MoVQ, self).to(*args, **kwargs)
523
+ self.dtype = dtype if dtype is not None else self.dtype
524
+ self.device = device if device is not None else self.device
525
+ return self
526
+
527
+ @torch.no_grad()
528
+ @apply_forward_hook
529
+ def encode(self, x):
530
+ h = self.encoder(x)
531
+ h = self.quant_conv(h)
532
+ return h
533
+
534
+ @torch.no_grad()
535
+ @apply_forward_hook
536
+ def decode(self, quant):
537
+ decoder_input = self.post_quant_conv(quant)
538
+ decoded = self.decoder(decoder_input, quant)
539
+ return decoded
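A hedged sketch of decoding latents with this VAE, mirroring the `self.vae.decode(latents)` call in the pipeline above. The snapshot path and the 4-channel, 128x128 latent shape are assumptions; the real shape comes from the UNet config and the image size. Note that `MoVQ.from_pretrained` rejects `torch.float16`, so the sketch uses `bfloat16`.

```python
import torch
from modeling_movq import MoVQ  # this file, imported locally

vae = MoVQ.from_pretrained(
    "/path/to/AniMemory-alpha", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")

# Latent shape is illustrative only.
latents = torch.randn(1, 4, 128, 128, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    image = vae.decode(latents)  # image tensor, roughly in [-1, 1]
print(image.shape)
```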