duzx16 committed
Commit 5fe53eb • Parent(s): 74d61a6

Fix checkpointing

modeling_chatglm.py CHANGED (+10 -6)
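For context, a rough usage sketch of the code path this commit touches. The repo id, the trust_remote_code loading, and the gradient_checkpointing_enable() call are assumptions based on the standard transformers API, not something stated in this commit:

# Hypothetical usage sketch, not taken from this commit: enable gradient
# checkpointing so GLMTransformer.forward routes each layer through
# torch.utils.checkpoint.checkpoint (the call fixed in the hunks below).
import torch
from transformers import AutoModel, AutoTokenizer

repo = "THUDM/chatglm2-6b"  # assumed repo id for this modeling_chatglm.py
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True).float()

model.gradient_checkpointing_enable()  # standard PreTrainedModel API; assumes this model wires it up
model.train()

inputs = tokenizer("hello", return_tensors="pt")
outputs = model(**inputs)
outputs[0].float().sum().backward()  # backward replays the checkpointed layers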
@@ -63,7 +63,7 @@ class PrefixEncoder(torch.nn.Module):
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
     """
 
-    def __init__(self, config):
+    def __init__(self, config: ChatGLMConfig):
         super().__init__()
         self.prefix_projection = config.prefix_projection
         if self.prefix_projection:
@@ -75,7 +75,8 @@ class PrefixEncoder(torch.nn.Module):
                 torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
             )
         else:
-            self.embedding = torch.nn.Embedding(config.pre_seq_len,
+            self.embedding = torch.nn.Embedding(config.pre_seq_len,
+                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
 
     def forward(self, prefix: torch.Tensor):
         if self.prefix_projection:
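With multi-query attention, each layer caches keys and values of width kv_channels per group rather than a full hidden_size, so the prefix table has to emit num_layers * kv_channels * multi_query_group_num * 2 features per prefix position. A quick shape check with illustrative config values (the numbers below are placeholders, not read from the commit):

# Shape sketch for the new embedding width; the config values are illustrative.
import torch

num_layers, kv_channels, multi_query_group_num, pre_seq_len = 28, 128, 2, 64

embedding = torch.nn.Embedding(pre_seq_len,
                               num_layers * kv_channels * multi_query_group_num * 2)

prefix = torch.arange(pre_seq_len).unsqueeze(0)  # (batch=1, prefix_len)
print(embedding(prefix).shape)  # torch.Size([1, 64, 14336]), 14336 = 28 * 128 * 2 * 2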
@@ -629,8 +630,8 @@ class GLMTransformer(torch.nn.Module):
                     hidden_states,
                     attention_mask,
                     rotary_pos_emb,
-
-                    use_cache
+                    kv_caches[index],
+                    use_cache
                 )
             else:
                 layer_ret = layer(
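The checkpointed call now passes kv_caches[index] and use_cache positionally, matching how torch.utils.checkpoint.checkpoint forwards arguments to the wrapped callable. A minimal standalone sketch of that calling convention (toy layer, not the real GLMBlock; use_reentrant=False is this sketch's choice, not something set in the diff):

# Self-contained sketch of torch.utils.checkpoint.checkpoint with positional
# arguments, mirroring the calling convention in the hunk above.
import torch
from torch.utils.checkpoint import checkpoint


def layer(hidden_states, attention_mask, kv_cache, use_cache):
    # Stand-in for a transformer block; activations are recomputed on backward.
    out = torch.tanh(hidden_states) * attention_mask
    return out, (out.detach() if use_cache else kv_cache)


hidden = torch.randn(2, 4, requires_grad=True)
mask = torch.ones(2, 4)

out, new_cache = checkpoint(layer, hidden, mask, None, True, use_reentrant=False)
out.sum().backward()
print(hidden.grad.shape)  # torch.Size([2, 4])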
@@ -737,6 +738,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if device is not None:
             init_kwargs["device"] = device
         self.embedding = init_method(Embedding, config, **init_kwargs)
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels
 
         # Rotary positional embeddings
         self.seq_length = config.seq_length
@@ -768,8 +772,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             batch_size,
             self.pre_seq_len,
             self.num_layers * 2,
-            self.
-            self.
+            self.multi_query_group_num,
+            self.kv_channels
         )
         # seq_len, b, nh, hidden_size
         past_key_values = self.dropout(past_key_values)
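The reshaped prefix cache has to line up element-for-element with the PrefixEncoder output, now split into per-layer key/value, multi-query group, and kv_channels axes (reading the attributes cached in ChatGLMModel.__init__ above). A quick consistency check with the same illustrative values as before:

# Shape sketch for the past_key_values view; config values are illustrative.
import torch

batch_size, pre_seq_len = 4, 64
num_layers, kv_channels, multi_query_group_num = 28, 128, 2

flat = torch.randn(batch_size, pre_seq_len,
                   num_layers * kv_channels * multi_query_group_num * 2)

past_key_values = flat.view(
    batch_size,
    pre_seq_len,
    num_layers * 2,          # one key slot and one value slot per layer
    multi_query_group_num,   # shared K/V heads under multi-query attention
    kv_channels              # per-head dimension
)
print(past_key_values.shape)  # torch.Size([4, 64, 56, 2, 128])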