ZeyuXie commited on
Commit
9a7456a
1 Parent(s): 361d70a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -22,9 +22,9 @@ class dotdict(dict):
22
  class InferRunner:
23
  def __init__(self, device):
24
  vae_config = json.load(open("ckpts/ldm/vae_config.json"))
25
- self.vae = AutoencoderKL(**vae_config).to(device)
26
- vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin")
27
  self.vae.load_state_dict(vae_weights)
 
28
 
29
  train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
30
  self.pico_model = PicoDiffusion(
 
22
  class InferRunner:
23
  def __init__(self, device):
24
  vae_config = json.load(open("ckpts/ldm/vae_config.json"))
25
+ vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location="cpu")
 
26
  self.vae.load_state_dict(vae_weights)
27
+ self.vae = AutoencoderKL(**vae_config).to(device)
28
 
29
  train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
30
  self.pico_model = PicoDiffusion(