Update modeling_chatglm.py
Browse files- modeling_chatglm.py +1 -1
modeling_chatglm.py
CHANGED
@@ -255,7 +255,7 @@ class CoreAttention(torch.nn.Module):
|
|
255 |
# [sk, b, np, hn] --> [b, np, sq, hn]
|
256 |
|
257 |
# context layer shape: [b, np, sq, hn]
|
258 |
-
output_size = (value_layer.size(
|
259 |
# change view [b * np, sk, hn]
|
260 |
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
261 |
# change view [b * np, sq, sk]
|
|
|
255 |
# [sk, b, np, hn] --> [b, np, sq, hn]
|
256 |
|
257 |
# context layer shape: [b, np, sq, hn]
|
258 |
+
output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
|
259 |
# change view [b * np, sk, hn]
|
260 |
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
261 |
# change view [b * np, sq, sk]
|