Update app.py
Browse files
app.py
CHANGED
@@ -94,7 +94,7 @@ def load_vae(vae_dir):
|
|
94 |
vae = CausalVideoAutoencoder.from_config(vae_config)
|
95 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
96 |
vae.load_state_dict(vae_state_dict)
|
97 |
-
return vae.to(device
|
98 |
|
99 |
def load_unet(unet_dir):
|
100 |
unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
|
@@ -103,7 +103,7 @@ def load_unet(unet_dir):
|
|
103 |
transformer = Transformer3DModel.from_config(transformer_config)
|
104 |
unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
|
105 |
transformer.load_state_dict(unet_state_dict, strict=True)
|
106 |
-
return transformer.to(device
|
107 |
|
108 |
def load_scheduler(scheduler_dir):
|
109 |
scheduler_config_path = scheduler_dir / "scheduler_config.json"
|
|
|
94 |
vae = CausalVideoAutoencoder.from_config(vae_config)
|
95 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
96 |
vae.load_state_dict(vae_state_dict)
|
97 |
+
return vae.to(device=device, dtype=torch.bfloat16)
|
98 |
|
99 |
def load_unet(unet_dir):
|
100 |
unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
|
|
|
103 |
transformer = Transformer3DModel.from_config(transformer_config)
|
104 |
unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
|
105 |
transformer.load_state_dict(unet_state_dict, strict=True)
|
106 |
+
return transformer.to(device=device, dtype=torch.bfloat16)
|
107 |
|
108 |
def load_scheduler(scheduler_dir):
|
109 |
scheduler_config_path = scheduler_dir / "scheduler_config.json"
|