Merge pull request #36 from LightricksResearch/bugfix/check_timesteps_eligibility
Browse files
xora/models/autoencoders/vae.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from typing import Optional, Union
|
2 |
|
3 |
import torch
|
|
|
4 |
import math
|
5 |
import torch.nn as nn
|
6 |
from diffusers import ConfigMixin, ModelMixin
|
@@ -60,6 +61,8 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
60 |
self.dims = dims
|
61 |
self.z_sample_size = 1
|
62 |
|
|
|
|
|
63 |
# only relevant if vae tiling is enabled
|
64 |
self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
|
65 |
|
@@ -257,7 +260,10 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
257 |
timesteps: Optional[torch.Tensor] = None,
|
258 |
) -> Union[DecoderOutput, torch.FloatTensor]:
|
259 |
z = self.post_quant_conv(z)
|
260 |
-
|
|
|
|
|
|
|
261 |
return dec
|
262 |
|
263 |
def decode(
|
|
|
1 |
from typing import Optional, Union
|
2 |
|
3 |
import torch
|
4 |
+
import inspect
|
5 |
import math
|
6 |
import torch.nn as nn
|
7 |
from diffusers import ConfigMixin, ModelMixin
|
|
|
61 |
self.dims = dims
|
62 |
self.z_sample_size = 1
|
63 |
|
64 |
+
self.decoder_params = inspect.signature(self.decoder.forward).parameters
|
65 |
+
|
66 |
# only relevant if vae tiling is enabled
|
67 |
self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
|
68 |
|
|
|
260 |
timesteps: Optional[torch.Tensor] = None,
|
261 |
) -> Union[DecoderOutput, torch.FloatTensor]:
|
262 |
z = self.post_quant_conv(z)
|
263 |
+
if "timesteps" in self.decoder_params:
|
264 |
+
dec = self.decoder(z, target_shape=target_shape, timesteps=timesteps)
|
265 |
+
else:
|
266 |
+
dec = self.decoder(z, target_shape=target_shape)
|
267 |
return dec
|
268 |
|
269 |
def decode(
|