duzx16 committed
Commit: fc442f7
Parent(s): 5fe53eb

Fix gradient checkpointing and prefix prompt

Browse files: modeling_chatglm.py (+4 / -4)
modeling_chatglm.py CHANGED

@@ -406,11 +406,11 @@ class SelfAttention(torch.nn.Module):
             key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
 
         # adjust key and value for inference
-        if use_cache:
-            if kv_cache is not None:
-                cache_k, cache_v = kv_cache
-                key_layer = torch.cat((cache_k, key_layer), dim=0)
-                value_layer = torch.cat((cache_v, value_layer), dim=0)
+        if kv_cache is not None:
+            cache_k, cache_v = kv_cache
+            key_layer = torch.cat((cache_k, key_layer), dim=0)
+            value_layer = torch.cat((cache_v, value_layer), dim=0)
+        if use_cache:
             kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
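The commit message suggests why the reordering matters: ChatGLM's prefix prompt is presumably injected by passing the learned prefix key/values in as kv_cache, while gradient checkpointing runs the forward pass with use_cache=False. In the old code the concatenation was nested under if use_cache:, so the prefix was silently dropped in exactly that combination. The sketch below (hypothetical function names and tensor shapes; not part of the commit) reproduces both versions of the logic to show the behavioral difference:

import torch

def adjust_kv_old(key_layer, value_layer, kv_cache, use_cache):
    # Pre-fix logic: the prefix key/values are only merged when use_cache is True.
    if use_cache:
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        kv_cache = (key_layer, value_layer)
    else:
        kv_cache = None
    return key_layer, value_layer, kv_cache

def adjust_kv_new(key_layer, value_layer, kv_cache, use_cache):
    # Post-fix logic: the merge happens unconditionally; use_cache only
    # controls whether the merged tensors are returned as the new cache.
    if kv_cache is not None:
        cache_k, cache_v = kv_cache
        key_layer = torch.cat((cache_k, key_layer), dim=0)
        value_layer = torch.cat((cache_v, value_layer), dim=0)
    if use_cache:
        kv_cache = (key_layer, value_layer)
    else:
        kv_cache = None
    return key_layer, value_layer, kv_cache

# Hypothetical shapes [seq_len, batch, heads, head_dim]; prefix of length 2.
prefix = (torch.zeros(2, 1, 2, 4), torch.zeros(2, 1, 2, 4))
k = torch.ones(3, 1, 2, 4)
v = torch.ones(3, 1, 2, 4)

# Training with gradient checkpointing implies use_cache=False.
k_old, _, _ = adjust_kv_old(k, v, prefix, use_cache=False)
k_new, _, _ = adjust_kv_new(k, v, prefix, use_cache=False)
print(k_old.shape)  # torch.Size([3, 1, 2, 4]) -> prefix dropped
print(k_new.shape)  # torch.Size([5, 1, 2, 4]) -> prefix attended to

Under this reading, the fix changes no behavior when use_cache=True; it only ensures the prefix key/values reach the attention computation when caching is disabled.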