AmitIsraeli commited on
Commit
bbf15e6
1 Parent(s): c2ee1c9

change mps rng

Browse files
Files changed (1) hide show
  1. models/var.py +1 -1
models/var.py CHANGED
@@ -47,7 +47,7 @@ class VAR(nn.Module):
47
  cur += pn ** 2
48
 
49
  self.num_stages_minus_1 = len(self.patch_nums) - 1
50
- self.rng = torch.Generator(device="mps")
51
 
52
  # 1. input (word) embedding
53
  quant: VectorQuantizer2 = vae_local.quantize
 
47
  cur += pn ** 2
48
 
49
  self.num_stages_minus_1 = len(self.patch_nums) - 1
50
+ self.rng = torch.Generator(device="cpu")
51
 
52
  # 1. input (word) embedding
53
  quant: VectorQuantizer2 = vae_local.quantize