Merge pull request #11 from LightricksResearch/rm-dist-util
Browse files
xora/models/autoencoders/vae_encode.py
CHANGED
@@ -6,8 +6,10 @@ from torch import Tensor
|
|
6 |
|
7 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
8 |
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
|
9 |
-
|
10 |
-
|
|
|
|
|
11 |
|
12 |
def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
|
13 |
"""
|
@@ -54,10 +56,12 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
|
|
54 |
encode_bs = len(media_items) // split_size
|
55 |
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
56 |
latents = []
|
57 |
-
|
|
|
58 |
for image_batch in media_items.split(encode_bs):
|
59 |
latents.append(vae.encode(image_batch).latent_dist.sample())
|
60 |
-
|
|
|
61 |
latents = torch.cat(latents, dim=0)
|
62 |
else:
|
63 |
latents = vae.encode(media_items).latent_dist.sample()
|
|
|
6 |
|
7 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
8 |
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
|
9 |
+
try:
|
10 |
+
import torch_xla.core.xla_model as xm
|
11 |
+
except:
|
12 |
+
pass
|
13 |
|
14 |
def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
|
15 |
"""
|
|
|
56 |
encode_bs = len(media_items) // split_size
|
57 |
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
58 |
latents = []
|
59 |
+
if media_items.device.type == "xla":
|
60 |
+
xm.mark_step()
|
61 |
for image_batch in media_items.split(encode_bs):
|
62 |
latents.append(vae.encode(image_batch).latent_dist.sample())
|
63 |
+
if media_items.device.type == "xla":
|
64 |
+
xm.mark_step()
|
65 |
latents = torch.cat(latents, dim=0)
|
66 |
else:
|
67 |
latents = vae.encode(media_items).latent_dist.sample()
|
xora/utils/dist_util.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
from enum import Enum
|
2 |
-
|
3 |
-
def execute_graph() -> None:
|
4 |
-
if _acceleration_type == AccelerationType.TPU:
|
5 |
-
xm.mark_step()
|
|
|
|
|
|
|
|
|
|
|
|