daniel shalem
commited on
Commit
•
1940326
1
Parent(s):
645fba0
Feature: Add mixed precision support and direct bfloat16 support.
Browse files
xora/examples/image_to_video.py
CHANGED
@@ -136,6 +136,12 @@ def main():
|
|
136 |
"--frame_rate", type=int, default=25, help="Frame rate for the output video"
|
137 |
)
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
# Prompts
|
140 |
parser.add_argument(
|
141 |
"--prompt",
|
@@ -224,6 +230,7 @@ def main():
|
|
224 |
is_video=True,
|
225 |
vae_per_channel_normalize=True,
|
226 |
conditioning_method=ConditioningMethod.FIRST_FRAME,
|
|
|
227 |
).images
|
228 |
|
229 |
# Save output video
|
|
|
136 |
"--frame_rate", type=int, default=25, help="Frame rate for the output video"
|
137 |
)
|
138 |
|
139 |
+
parser.add_argument(
|
140 |
+
"--mixed_precision",
|
141 |
+
action="store_true",
|
142 |
+
help="Mixed precision in float32 and bfloat16",
|
143 |
+
)
|
144 |
+
|
145 |
# Prompts
|
146 |
parser.add_argument(
|
147 |
"--prompt",
|
|
|
230 |
is_video=True,
|
231 |
vae_per_channel_normalize=True,
|
232 |
conditioning_method=ConditioningMethod.FIRST_FRAME,
|
233 |
+
mixed_precision=args.mixed_precision,
|
234 |
).images
|
235 |
|
236 |
# Save output video
|
xora/models/transformers/transformer3d.py
CHANGED
@@ -305,7 +305,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
|
|
305 |
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
306 |
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
307 |
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
308 |
-
return cos_freq, sin_freq
|
309 |
|
310 |
def forward(
|
311 |
self,
|
|
|
305 |
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
306 |
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
307 |
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
308 |
+
return cos_freq.to(dtype), sin_freq.to(dtype)
|
309 |
|
310 |
def forward(
|
311 |
self,
|
xora/pipelines/pipeline_xora_video.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
9 |
|
10 |
import torch
|
11 |
import torch.nn.functional as F
|
|
|
12 |
from diffusers.image_processor import VaeImageProcessor
|
13 |
from diffusers.models import AutoencoderKL
|
14 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
@@ -758,6 +759,7 @@ class XoraVideoPipeline(DiffusionPipeline):
|
|
758 |
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
759 |
clean_caption: bool = True,
|
760 |
media_items: Optional[torch.FloatTensor] = None,
|
|
|
761 |
**kwargs,
|
762 |
) -> Union[ImagePipelineOutput, Tuple]:
|
763 |
"""
|
@@ -1006,16 +1008,22 @@ class XoraVideoPipeline(DiffusionPipeline):
|
|
1006 |
|
1007 |
if conditioning_mask is not None:
|
1008 |
current_timestep = current_timestep * (1 - conditioning_mask)
|
|
|
|
|
|
|
|
|
|
|
1009 |
|
1010 |
# predict noise model_output
|
1011 |
-
|
1012 |
-
|
1013 |
-
|
1014 |
-
|
1015 |
-
|
1016 |
-
|
1017 |
-
|
1018 |
-
|
|
|
1019 |
|
1020 |
# perform guidance
|
1021 |
if do_classifier_free_guidance:
|
|
|
9 |
|
10 |
import torch
|
11 |
import torch.nn.functional as F
|
12 |
+
from contextlib import nullcontext
|
13 |
from diffusers.image_processor import VaeImageProcessor
|
14 |
from diffusers.models import AutoencoderKL
|
15 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
|
|
759 |
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
760 |
clean_caption: bool = True,
|
761 |
media_items: Optional[torch.FloatTensor] = None,
|
762 |
+
mixed_precision: bool = False,
|
763 |
**kwargs,
|
764 |
) -> Union[ImagePipelineOutput, Tuple]:
|
765 |
"""
|
|
|
1008 |
|
1009 |
if conditioning_mask is not None:
|
1010 |
current_timestep = current_timestep * (1 - conditioning_mask)
|
1011 |
+
# Choose the appropriate context manager based on `mixed_precision`
|
1012 |
+
if mixed_precision:
|
1013 |
+
context_manager = torch.autocast("cuda", dtype=torch.bfloat16)
|
1014 |
+
else:
|
1015 |
+
context_manager = nullcontext() # Dummy context manager
|
1016 |
|
1017 |
# predict noise model_output
|
1018 |
+
with context_manager:
|
1019 |
+
noise_pred = self.transformer(
|
1020 |
+
latent_model_input.to(self.transformer.dtype),
|
1021 |
+
indices_grid,
|
1022 |
+
encoder_hidden_states=prompt_embeds.to(self.transformer.dtype),
|
1023 |
+
encoder_attention_mask=prompt_attention_mask,
|
1024 |
+
timestep=current_timestep,
|
1025 |
+
return_dict=False,
|
1026 |
+
)[0]
|
1027 |
|
1028 |
# perform guidance
|
1029 |
if do_classifier_free_guidance:
|