lllyasviel commited on
Commit
20f349b
·
1 Parent(s): a396caf
sgm/modules/diffusionmodules/util.py CHANGED
@@ -272,8 +272,7 @@ class SiLU(nn.Module):
272
 
273
  class GroupNorm32(nn.GroupNorm):
274
  def forward(self, x):
275
- self.weight = self.weight.float()
276
- self.bias = self.bias.float()
277
  return super().forward(x.float()).type(x.dtype)
278
 
279
 
 
272
 
273
  class GroupNorm32(nn.GroupNorm):
274
  def forward(self, x):
275
+ self.to(torch.float32)
 
276
  return super().forward(x.float()).type(x.dtype)
277
 
278