import torch
from diffusers import AutoencoderKL
from einops import rearrange
from torch import Tensor
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
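
# torch_xla is optional: it is only needed on XLA devices (e.g. TPU), where
# xm.mark_step() is used to flush pending lazily-traced operations.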
try:
import torch_xla.core.xla_model as xm
except ImportError:
xm = None


def vae_encode(
media_items: Tensor,
vae: AutoencoderKL,
split_size: int = 1,
    vae_per_channel_normalize: bool = False,
) -> Tensor:
"""
Encodes media items (images or videos) into latent representations using a specified VAE model.
The function supports processing batches of images or video frames and can handle the processing
in smaller sub-batches if needed.
Args:
media_items (Tensor): A torch Tensor containing the media items to encode. The expected
shape is (batch_size, channels, height, width) for images or (batch_size, channels,
frames, height, width) for videos.
vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
pre-configured and loaded with the appropriate model weights.
        split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
            If set to more than 1, the input media items are processed in smaller batches according to
            this value. The batch size must be divisible by it. Defaults to 1, which processes all
            items in a single batch.
        vae_per_channel_normalize (bool, optional): If True, normalizes the latents per channel using
            the VAE's `mean_of_means` and `std_of_means` statistics instead of the global
            `scaling_factor`. Defaults to False.
Returns:
        Tensor: A torch Tensor of the encoded latent representations, shaped like the input but with
            the channel dimension replaced by the VAE's latent channels and the spatial (and, for
            video VAEs, temporal) dimensions reduced by the VAE's down-scale factors.
Examples:
>>> import torch
>>> from diffusers import AutoencoderKL
>>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> videos = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames each.
        >>> latents = vae_encode(videos, vae)
>>> print(latents.shape) # Output shape will depend on the model's latent configuration.
Note:
        For video inputs passed to a non-video VAE, the frames are folded into the batch dimension
        and the media item is encoded frame by frame.
"""
is_video_shaped = media_items.dim() == 5
batch_size, channels = media_items.shape[0:2]
if channels != 3:
raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
if is_video_shaped and not isinstance(
vae, (VideoAutoencoder, CausalVideoAutoencoder)
):
media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
if split_size > 1:
if len(media_items) % split_size != 0:
raise ValueError(
"Error: The batch size must be divisible by 'train.vae_bs_split"
)
encode_bs = len(media_items) // split_size
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
latents = []
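        # On XLA devices, mark_step() forces execution of the pending lazy graph
        # between sub-batches instead of tracing the whole loop at once.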
if media_items.device.type == "xla":
xm.mark_step()
for image_batch in media_items.split(encode_bs):
latents.append(vae.encode(image_batch).latent_dist.sample())
if media_items.device.type == "xla":
xm.mark_step()
latents = torch.cat(latents, dim=0)
else:
latents = vae.encode(media_items).latent_dist.sample()
latents = normalize_latents(latents, vae, vae_per_channel_normalize)
if is_video_shaped and not isinstance(
vae, (VideoAutoencoder, CausalVideoAutoencoder)
):
latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
return latents


def vae_decode(
latents: Tensor,
vae: AutoencoderKL,
is_video: bool = True,
split_size: int = 1,
    vae_per_channel_normalize: bool = False,
) -> Tensor:
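    """
    Decodes latent representations back into pixel-space media (images or videos) using a VAE.

    Args:
        latents (Tensor): Latents of shape (batch_size, channels, height, width) for images or
            (batch_size, channels, frames, height, width) for videos.
        vae (AutoencoderKL): The VAE model to decode with.
        is_video (bool, optional): Whether the latents represent video frames. Defaults to True.
        split_size (int, optional): Number of sub-batches to decode separately. The batch size must
            be divisible by this value. Defaults to 1.
        vae_per_channel_normalize (bool, optional): Must match the value used during encoding, so
            the latents are un-normalized consistently. Defaults to False.

    Returns:
        Tensor: The decoded media, with the channel, spatial, and (for video VAEs) temporal
            dimensions expanded by the VAE's scale factors.
    """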
is_video_shaped = latents.dim() == 5
batch_size = latents.shape[0]
if is_video_shaped and not isinstance(
vae, (VideoAutoencoder, CausalVideoAutoencoder)
):
latents = rearrange(latents, "b c n h w -> (b n) c h w")
if split_size > 1:
if len(latents) % split_size != 0:
raise ValueError(
"Error: The batch size must be divisible by 'train.vae_bs_split"
)
encode_bs = len(latents) // split_size
image_batch = [
_run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
for latent_batch in latents.split(encode_bs)
]
images = torch.cat(image_batch, dim=0)
else:
images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
if is_video_shaped and not isinstance(
vae, (VideoAutoencoder, CausalVideoAutoencoder)
):
images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
return images


def _run_decoder(
    latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize: bool = False
) -> Tensor:
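    """
    Un-normalizes the latents and runs a single decode pass through the VAE.

    For video autoencoders, the target pixel-space shape is derived from the latent shape and the
    VAE's temporal/spatial down-scale factors; standard image VAEs decode without a target shape.
    """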
if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
*_, fl, hl, wl = latents.shape
temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
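        # Reconstruct the expected pixel-space shape from the latent shape and the
        # VAE's down-scale factors; a non-video input decodes to a single frame.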
latents = latents.to(vae.dtype)
image = vae.decode(
un_normalize_latents(latents, vae, vae_per_channel_normalize),
return_dict=False,
target_shape=(
1,
3,
fl * temporal_scale if is_video else 1,
hl * spatial_scale,
wl * spatial_scale,
),
)[0]
else:
image = vae.decode(
un_normalize_latents(latents, vae, vae_per_channel_normalize),
return_dict=False,
)[0]
return image


def get_vae_size_scale_factor(vae: AutoencoderKL) -> tuple[int, int, int]:
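    """
    Returns the (temporal, height, width) down-scale factors of the given VAE, i.e. how many input
    frames/pixels map to a single latent element along each axis.
    """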
if isinstance(vae, CausalVideoAutoencoder):
spatial = vae.spatial_downscale_factor
temporal = vae.temporal_downscale_factor
else:
down_blocks = len(
[
block
for block in vae.encoder.down_blocks
if isinstance(block.downsample, Downsample3D)
]
)
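        # Each Downsample3D block halves the spatial resolution; the patchify step
        # contributes an additional factor of `patch_size`.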
spatial = vae.config.patch_size * 2**down_blocks
temporal = (
vae.config.patch_size_t * 2**down_blocks
if isinstance(vae, VideoAutoencoder)
else 1
)
return (temporal, spatial, spatial)


def normalize_latents(
latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
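    """
    Scales raw VAE latents into the diffusion model's working range: per-channel standardization
    using the VAE's `mean_of_means` / `std_of_means` statistics when `vae_per_channel_normalize`
    is True, otherwise a global multiplication by `vae.config.scaling_factor`.
    """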
return (
(latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
/ vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
if vae_per_channel_normalize
else latents * vae.config.scaling_factor
)


def un_normalize_latents(
latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
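    """
    Inverts `normalize_latents`, mapping latents back to the VAE's native range before decoding.
    """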
return (
latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
+ vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
if vae_per_channel_normalize
else latents / vae.config.scaling_factor
)