NIRVANALAN commited on
Commit
829eca9
·
1 Parent(s): 89337c5

update dep

Browse files
dit/dit_models_xformers.py CHANGED
@@ -29,7 +29,8 @@ try:
29
  from apex.normalization import FusedLayerNorm as LayerNorm
30
  except:
31
  from torch.nn import LayerNorm
32
- from torch.nn import RMSNorm # requires torch2.4
 
33
 
34
  # from torch.nn import LayerNorm
35
  # from xformers import triton
 
29
  from apex.normalization import FusedLayerNorm as LayerNorm
30
  except:
31
  from torch.nn import LayerNorm
32
+ # from torch.nn import RMSNorm # requires torch2.4
33
+ from dit.norm import RMSNorm
34
 
35
  # from torch.nn import LayerNorm
36
  # from xformers import triton
dit/norm.py CHANGED
@@ -6,10 +6,10 @@ def rms_norm(x, weight=None, eps=1e-05):
6
 
7
  class RMSNorm(torch.nn.Module):
8
 
9
- def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
10
  super().__init__()
11
  self.eps = eps
12
- if weight:
13
  self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
14
  else:
15
  self.register_parameter('weight', None)
 
6
 
7
  class RMSNorm(torch.nn.Module):
8
 
9
+ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=None, device=None):
10
  super().__init__()
11
  self.eps = eps
12
+ if elementwise_affine:
13
  self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
14
  else:
15
  self.register_parameter('weight', None)
ldm/modules/attention.py CHANGED
@@ -17,7 +17,7 @@ from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
17
  try:
18
  from apex.normalization import FusedRMSNorm as RMSNorm
19
  except:
20
- from torch.nn import RMSNorm # requires torch2.4
21
 
22
 
23
  def exists(val):
 
17
  try:
18
  from apex.normalization import FusedRMSNorm as RMSNorm
19
  except:
20
+ from dit.norm import RMSNorm
21
 
22
 
23
  def exists(val):
vit/vision_transformer.py CHANGED
@@ -34,7 +34,11 @@ from .utils import trunc_normal_
34
 
35
  from pdb import set_trace as st
36
  # import apex
37
- from apex.normalization import FusedRMSNorm as RMSNorm
 
 
 
 
38
  # from apex.normalization import FusedLayerNorm as LayerNorm
39
 
40
  try:
 
34
 
35
  from pdb import set_trace as st
36
  # import apex
37
+ try:
38
+ from apex.normalization import FusedRMSNorm as RMSNorm
39
+ except:
40
+ from dit.norm import RMSNorm
41
+
42
  # from apex.normalization import FusedLayerNorm as LayerNorm
43
 
44
  try: