# File size: 9,949 Bytes (revision 43b7e92)
# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
# William Peebles and Saining Xie
#
# Copyright (c) 2021 OpenAI
# MIT License
#
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
import torch
from ...models import AutoencoderKL, DiTTransformer2DModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class DiTPipeline(DiffusionPipeline):
    r"""
    Pipeline for image generation based on a Transformer backbone instead of a UNet.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Parameters:
        transformer ([`DiTTransformer2DModel`]):
            A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        id2label (`Dict[int, str]`, *optional*):
            Mapping from class id to a comma-separated string of label names. Every name in the string is
            registered in `self.labels` as an alias of that class id. When omitted, `self.labels` is empty.
    """

    model_cpu_offload_seq = "transformer->vae"

    def __init__(
        self,
        transformer: DiTTransformer2DModel,
        vae: AutoencoderKL,
        scheduler: KarrasDiffusionSchedulers,
        id2label: Optional[Dict[int, str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

        # Build a label-name -> class-id dictionary for easier use. One class id may carry several
        # comma-separated names; each of them becomes its own key.
        self.labels = {}
        if id2label is not None:
            for key, value in id2label.items():
                for label in value.split(","):
                    self.labels[label.strip()] = int(key)
            # Keep the mapping sorted by label name so that printing it is easy to scan.
            self.labels = dict(sorted(self.labels.items()))

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""
        Map label strings from ImageNet to corresponding class ids.

        Parameters:
            label (`str` or `list` of `str`):
                Label strings to be mapped to class ids. A single string is treated as one label.

        Returns:
            `list` of `int`:
                Class ids to be processed by pipeline.

        Raises:
            `ValueError`: If one of the labels is not a key of `self.labels`.
        """
        if isinstance(label, str):
            # `list("white shark")` would split the label into single characters, so wrap it instead.
            label = [label]
        elif not isinstance(label, list):
            label = list(label)

        for name in label:
            if name not in self.labels:
                raise ValueError(
                    f"{name} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
                )

        return [self.labels[name] for name in label]

    @torch.no_grad()
    def __call__(
        self,
        class_labels: List[int],
        guidance_scale: float = 4.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        The call function to the pipeline for generation.

        Args:
            class_labels (List[int]):
                List of ImageNet class labels for the images to be generated. One image is generated per entry.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                A higher guidance scale value encourages the model to generate images closely linked to the
                requested class label at the expense of lower image quality. Guidance scale is enabled when
                `guidance_scale > 1`.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. With `"pil"` the images are converted to `PIL.Image`;
                any other value returns the `np.array` of shape `(batch, height, width, channels)`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        ```py
        >>> from diffusers import DiTPipeline, DPMSolverMultistepScheduler
        >>> import torch

        >>> pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
        >>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        >>> pipe = pipe.to("cuda")

        >>> # pick words from Imagenet class labels
        >>> pipe.labels  # to print all available words

        >>> # pick words that exist in ImageNet
        >>> words = ["white shark", "umbrella"]

        >>> class_ids = pipe.get_label_ids(words)

        >>> generator = torch.manual_seed(33)
        >>> output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)

        >>> image = output.images[0]  # label 'white shark'
        ```

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element holds the generated images.
        """
        batch_size = len(class_labels)
        latent_size = self.transformer.config.sample_size
        latent_channels = self.transformer.config.in_channels
        device = self._execution_device
        use_guidance = guidance_scale > 1

        latents = randn_tensor(
            shape=(batch_size, latent_channels, latent_size, latent_size),
            generator=generator,
            device=device,
            dtype=self.transformer.dtype,
        )
        # With guidance the batch is doubled: first half conditional, second half unconditional.
        latent_model_input = torch.cat([latents] * 2) if use_guidance else latents

        class_labels = torch.tensor(class_labels, device=device).reshape(-1)
        if use_guidance:
            # Index 1000 is used as the "null" class for the unconditional half
            # (presumably one past the 1000 ImageNet classes -- confirm against the transformer config).
            class_null = torch.tensor([1000] * batch_size, device=device)
            class_labels_input = torch.cat([class_labels, class_null], 0)
        else:
            class_labels_input = class_labels

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            if use_guidance:
                # Both halves must start each step from the same latents, so re-duplicate the first half.
                half = latent_model_input[: len(latent_model_input) // 2]
                latent_model_input = torch.cat([half, half], dim=0)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            timesteps = t
            if not torch.is_tensor(timesteps):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                if isinstance(timesteps, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
            elif len(timesteps.shape) == 0:
                timesteps = timesteps[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timesteps = timesteps.expand(latent_model_input.shape[0])

            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input, timestep=timesteps, class_labels=class_labels_input
            ).sample

            # perform guidance
            if use_guidance:
                # Only the first `latent_channels` channels are guided; the remaining channels pass through.
                eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
                cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)

                half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
                eps = torch.cat([half_eps, half_eps], dim=0)

                noise_pred = torch.cat([eps, rest], dim=1)

            # learned sigma: when the model outputs twice the latent channels, keep only the first half
            if self.transformer.config.out_channels // 2 == latent_channels:
                model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
            else:
                model_output = noise_pred

            # compute previous image: x_t -> x_t-1
            latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample

        if use_guidance:
            # Drop the duplicated second half of the batch.
            latents, _ = latent_model_input.chunk(2, dim=0)
        else:
            latents = latent_model_input

        latents = 1 / self.vae.config.scaling_factor * latents
        samples = self.vae.decode(latents).sample

        # Map the decoder output from [-1, 1] to [0, 1].
        samples = (samples / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            samples = self.numpy_to_pil(samples)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (samples,)

        return ImagePipelineOutput(images=samples)
|