realantonvoronov commited on
Commit
ebf782e
1 Parent(s): 55ca09f

fix rng generator to work in cpu-only environment too

Browse files
Files changed (1) hide show
  1. models/switti.py +1 -1
models/switti.py CHANGED
@@ -71,7 +71,7 @@ class Switti(nn.Module):
71
  self.rope = rope
72
 
73
  self.num_stages_minus_1 = len(self.patch_nums) - 1
74
- self.rng = torch.Generator(device="cuda")
75
 
76
  # 1. input (word) embedding
77
  self.word_embed = nn.Linear(self.Cvae, self.C)
 
71
  self.rope = rope
72
 
73
  self.num_stages_minus_1 = len(self.patch_nums) - 1
74
+ self.rng = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
75
 
76
  # 1. input (word) embedding
77
  self.word_embed = nn.Linear(self.Cvae, self.C)