ford442 commited on
Commit
3632a8e
1 Parent(s): c5300ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -94,7 +94,7 @@ def load_vae(vae_dir):
94
  vae = CausalVideoAutoencoder.from_config(vae_config)
95
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
96
  vae.load_state_dict(vae_state_dict)
97
- return vae.to(device).to(torch.bfloat16)
98
 
99
  def load_unet(unet_dir):
100
  unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
@@ -103,7 +103,7 @@ def load_unet(unet_dir):
103
  transformer = Transformer3DModel.from_config(transformer_config)
104
  unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
105
  transformer.load_state_dict(unet_state_dict, strict=True)
106
- return transformer.to(device).to(torch.bfloat16)
107
 
108
  def load_scheduler(scheduler_dir):
109
  scheduler_config_path = scheduler_dir / "scheduler_config.json"
 
94
  vae = CausalVideoAutoencoder.from_config(vae_config)
95
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
96
  vae.load_state_dict(vae_state_dict)
97
+ return vae.to(device=device, dtype=torch.bfloat16)
98
 
99
  def load_unet(unet_dir):
100
  unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
 
103
  transformer = Transformer3DModel.from_config(transformer_config)
104
  unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
105
  transformer.load_state_dict(unet_state_dict, strict=True)
106
+ return transformer.to(device=device, dtype=torch.bfloat16)
107
 
108
  def load_scheduler(scheduler_dir):
109
  scheduler_config_path = scheduler_dir / "scheduler_config.json"