erichardson
commited on
Commit
•
028b6a1
1
Parent(s):
645fba0
VAE: Support retuning intermediate features for 3d perceptual loss
Browse files
xora/models/autoencoders/video_autoencoder.py
CHANGED
@@ -310,7 +310,9 @@ class Encoder(nn.Module):
|
|
310 |
* self.patch_size
|
311 |
)
|
312 |
|
313 |
-
def forward(
|
|
|
|
|
314 |
r"""The forward method of the `Encoder` class."""
|
315 |
|
316 |
downsample_in_time = sample.shape[2] != 1
|
@@ -332,10 +334,14 @@ class Encoder(nn.Module):
|
|
332 |
else lambda x: x
|
333 |
)
|
334 |
|
|
|
|
|
335 |
for down_block in self.down_blocks:
|
336 |
sample = checkpoint_fn(down_block)(
|
337 |
sample, downsample_in_time=downsample_in_time
|
338 |
)
|
|
|
|
|
339 |
|
340 |
sample = checkpoint_fn(self.mid_block)(sample)
|
341 |
|
@@ -363,6 +369,11 @@ class Encoder(nn.Module):
|
|
363 |
else:
|
364 |
raise ValueError(f"Invalid input shape: {sample.shape}")
|
365 |
|
|
|
|
|
|
|
|
|
|
|
366 |
return sample
|
367 |
|
368 |
|
|
|
310 |
* self.patch_size
|
311 |
)
|
312 |
|
313 |
+
def forward(
|
314 |
+
self, sample: torch.FloatTensor, return_features=False
|
315 |
+
) -> torch.FloatTensor:
|
316 |
r"""The forward method of the `Encoder` class."""
|
317 |
|
318 |
downsample_in_time = sample.shape[2] != 1
|
|
|
334 |
else lambda x: x
|
335 |
)
|
336 |
|
337 |
+
if return_features:
|
338 |
+
features = []
|
339 |
for down_block in self.down_blocks:
|
340 |
sample = checkpoint_fn(down_block)(
|
341 |
sample, downsample_in_time=downsample_in_time
|
342 |
)
|
343 |
+
if return_features:
|
344 |
+
features.append(sample)
|
345 |
|
346 |
sample = checkpoint_fn(self.mid_block)(sample)
|
347 |
|
|
|
369 |
else:
|
370 |
raise ValueError(f"Invalid input shape: {sample.shape}")
|
371 |
|
372 |
+
if return_features:
|
373 |
+
features.append(
|
374 |
+
sample[:, sample.shape[1] // 2, ...]
|
375 |
+
) # Add the latent means as final feature
|
376 |
+
return sample, features
|
377 |
return sample
|
378 |
|
379 |
|