zRzRzRzRzRzRzR committed on
Commit 8ba56be
1 Parent(s): 7031540
Files changed (1)
  1. visual.py +18 -9
visual.py CHANGED
@@ -6,6 +6,7 @@ from transformers.activations import ACT2FN
 import math
 from torch.nn import LayerNorm
 
+
 def standard_attention(query_layer, key_layer, value_layer, scaling_attention_score=True):
     if scaling_attention_score:
         query_layer = query_layer / math.sqrt(query_layer.shape[-1])
@@ -16,11 +17,12 @@ def standard_attention(query_layer, key_layer, value_layer, scaling_attention_sc
     context_layer = torch.matmul(attention_probs, value_layer)
     return context_layer
 
+
 def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_score=True):
     if int(torch.__version__.split('.')[0]) >= 2 and scaling_attention_score:
         # Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None.
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_layer, key_layer, value_layer,
+            query_layer, key_layer, value_layer,
             attn_mask=None,
             dropout_p=0.,
             is_causal=False
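
Note (not part of the commit): with attn_mask=None, dropout_p=0. and is_causal=False, the PyTorch >= 2.0 fast path above computes the same softmax(Q·Kᵀ/√d)·V as the manual fallback in standard_attention. A minimal sketch of that equivalence check, with made-up shapes:

import math
import torch

# B, H, L, D -- arbitrary illustrative sizes
q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))

# PyTorch >= 2.0 fused path, as called in attention_fn_default
fast = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0., is_causal=False
)

# What standard_attention computes: scale Q, softmax over the scores, weight V
probs = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1]), dim=-1)
manual = probs @ v

print(torch.allclose(fast, manual, atol=1e-5))  # expected: True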
@@ -31,10 +33,12 @@ def attention_fn_default(query_layer, key_layer, value_layer, scaling_attention_
             query_layer, key_layer, value_layer, scaling_attention_score=scaling_attention_score
         )
 
+
 class PatchEmbedding(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size)
+        self.proj = nn.Conv2d(config.in_channels, config.hidden_size, kernel_size=config.patch_size,
+                              stride=config.patch_size)
         self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
         self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)
 
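
For context (illustrative, not from the commit): in PatchEmbedding the Conv2d has kernel_size == stride == patch_size, so it cuts the image into non-overlapping patches and projects each one to hidden_size in a single call; the cls_embedding and position_embedding parameters defined alongside it handle the class token and positional information. A shape sketch with assumed numbers (a 224×224 RGB image and 14-pixel patches, not the model's actual configuration):

import torch
import torch.nn as nn

in_channels, hidden_size, patch_size = 3, 1792, 14           # assumed values
proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size,
                 stride=patch_size)

images = torch.randn(2, in_channels, 224, 224)               # B, C, H, W
patches = proj(images)                                       # (2, 1792, 16, 16): one projection per patch
tokens = patches.flatten(2).transpose(1, 2)                  # (2, 256, 1792): 16*16 patch tokens
print(patches.shape, tokens.shape)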
@@ -62,11 +66,11 @@ class Attention(nn.Module):
         qkv = self.query_key_value(x)
         qkv = qkv.reshape(B, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, H, L, D
         q, k, v = qkv[0], qkv[1], qkv[2]
-
+
         out = attention_fn_default(
             q, k, v
         )
-        output = self.dense(out.transpose(1, 2).reshape(B, L, -1))
+        output = self.dense(out.transpose(1, 2).view(B, L, -1))
         output = self.output_dropout(output)
         return output
 
@@ -105,7 +109,9 @@ class TransformerLayer(nn.Module):
         attention_output = self.input_layernorm(self.attention(attention_input))
         hidden_states = attention_input + attention_output
         mlp_input = hidden_states
-        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input))
+
+        # https://github.com/THUDM/GLM-4/issues/350
+        mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)).to(mlp_input.device)
         output = mlp_input + mlp_output
         return output
 
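
The .to(mlp_input.device) added here, like the .to(x.device) calls further down, follows the report in https://github.com/THUDM/GLM-4/issues/350: when the weights are spread over several GPUs (e.g. loading with device_map="auto"), a submodule's parameters can sit on a different device than the activation feeding the residual connection, and the addition then fails with a device-mismatch error. A rough sketch of that failure mode and of the fix, assuming a machine with at least two CUDA devices:

import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    mlp = nn.Linear(8, 8).to("cuda:1")          # parameters placed on the second GPU
    x = torch.randn(2, 8, device="cuda:0")      # residual stream lives on the first GPU

    try:
        y = x + mlp(x.to("cuda:1"))             # output is on cuda:1, residual on cuda:0
    except RuntimeError as err:
        print("without .to():", err)            # "Expected all tensors to be on the same device ..."

    y = x + mlp(x.to("cuda:1")).to(x.device)    # the commit's pattern: move the output back first
    print("with .to():", y.device)              # cuda:0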
@@ -147,7 +153,8 @@ class EVA2CLIPModel(nn.Module):
         self.patch_embedding = PatchEmbedding(vision_config)
         self.transformer = Transformer(vision_config)
         self.linear_proj = GLU(config, in_features=config.hidden_size)
-        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2, stride=2)
+        self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, out_channels=config.hidden_size, kernel_size=2,
+                              stride=2)
         self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
         self.scaling_factor = vision_config.scaling_factor
@@ -158,14 +165,16 @@
         x = x[:, 1:]
 
         b, s, h = x.shape
-        grid_size = int(s**0.5)
+        grid_size = int(s ** 0.5)
         x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
         x = self.conv(x)
 
         x = x.flatten(2).transpose(1, 2)
         x = self.linear_proj(x)
-        boi = self.boi.expand(x.shape[0], -1, -1)
-        eoi = self.eoi.expand(x.shape[0], -1, -1)
+
+        # https://github.com/THUDM/GLM-4/issues/350
+        boi = self.boi.expand(x.shape[0], -1, -1).to(x.device)
+        eoi = self.eoi.expand(x.shape[0], -1, -1).to(x.device)
         x = torch.cat((boi, x, eoi), dim=1)
         x = x / self.scaling_factor
         return x
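
For context (illustrative only, not from the commit): the forward pass above drops the cls token, folds the remaining s patch tokens back into a grid_size × grid_size grid, halves the grid per side with the stride-2 conv (so 4× fewer tokens), flattens it back into a sequence, and wraps it with the boi/eoi tokens. A shape trace with assumed sizes (an 80×80 patch grid and small channel counts; self.linear_proj and the scaling_factor division are skipped):

import torch
import torch.nn as nn

b, s, h, out_h = 2, 6400, 32, 64                              # assumed sizes
x = torch.randn(b, s, h)                                      # patch tokens, cls token already removed

grid_size = int(s ** 0.5)                                     # 80
x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)    # (2, 32, 80, 80)
x = nn.Conv2d(h, out_h, kernel_size=2, stride=2)(x)           # (2, 64, 40, 40): grid halved per side
x = x.flatten(2).transpose(1, 2)                              # (2, 1600, 64): back to a token sequence

boi = torch.zeros(b, 1, out_h)                                # stand-ins for self.boi / self.eoi
eoi = torch.zeros(b, 1, out_h)
x = torch.cat((boi, x, eoi), dim=1)                           # (2, 1602, 64)
print(x.shape)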
 